CANDIDATE NUMBERS: 39756, 37248, 45812, 44444, 41987
1. Setup¶
In [ ]:
!pip install opencv-python-headless
!pip install tqdm
!pip install pytorch-msssim lpips
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from tqdm import tqdm
import glob
# img processing
import cv2
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from torchvision.transforms import Compose, Resize, ToTensor
# pckgs for results librarys
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric
import pandas as pd
import time
import matplotlib.pyplot as plt
import lpips
from pytorch_msssim import ssim
import os
from google.colab import drive
from google.colab import auth
import shutil
import math
import traceback
try:
from pytorch_msssim import ssim as ssim_pytorch, ms_ssim
from skimage.metrics import peak_signal_noise_ratio as psnr_metric_skimage
from skimage.metrics import structural_similarity as ssim_metric_skimage
except ImportError as e:
print(f"Error importing metric libraries: {e}")
from torch.optim.lr_scheduler import ReduceLROnPlateau
Requirement already satisfied: opencv-python-headless in /usr/local/lib/python3.11/dist-packages (4.11.0.86) Requirement already satisfied: numpy>=1.21.2 in /usr/local/lib/python3.11/dist-packages (from opencv-python-headless) (2.0.2) Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (4.67.1) Collecting pytorch-msssim Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB) Collecting lpips Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB) Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from pytorch-msssim) (2.6.0+cu124) Requirement already satisfied: torchvision>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from lpips) (0.21.0+cu124) Requirement already satisfied: numpy>=1.14.3 in /usr/local/lib/python3.11/dist-packages (from lpips) (2.0.2) Requirement already satisfied: scipy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from lpips) (1.15.2) Requirement already satisfied: tqdm>=4.28.1 in /usr/local/lib/python3.11/dist-packages (from lpips) (4.67.1) Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (3.18.0) Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (4.13.2) Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (3.4.2) Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (3.1.6) Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (2025.3.2) Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->pytorch-msssim) Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->pytorch-msssim) Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->pytorch-msssim) Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->pytorch-msssim) Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->pytorch-msssim) Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->pytorch-msssim) Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-curand-cu12==10.3.5.147 (from torch->pytorch-msssim) Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch->pytorch-msssim) Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch->pytorch-msssim) Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB) Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (0.6.2) Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (2.21.5) Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (12.4.127) Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch->pytorch-msssim) Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB) Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (3.2.0) Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch->pytorch-msssim) (1.13.1) Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch->pytorch-msssim) (1.3.0) Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision>=0.2.1->lpips) (11.2.1) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch->pytorch-msssim) (3.0.2) Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB) Downloading lpips-0.1.4-py3-none-any.whl (53 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 53.8/53.8 kB 3.3 MB/s eta 0:00:00 Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 2.0 MB/s eta 0:00:00 Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 122.7 MB/s eta 0:00:00 Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 91.9 MB/s eta 0:00:00 Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 57.2 MB/s eta 0:00:00 Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.8 MB/s eta 0:00:00 Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 6.2 MB/s eta 0:00:00 Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 13.6 MB/s eta 0:00:00 Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 7.6 MB/s eta 0:00:00 Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 5.7 MB/s eta 0:00:00 Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 74.7 MB/s eta 0:00:00 Installing collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, pytorch-msssim, lpips Attempting uninstall: nvidia-nvjitlink-cu12 Found existing installation: nvidia-nvjitlink-cu12 12.5.82 Uninstalling nvidia-nvjitlink-cu12-12.5.82: Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82 Attempting uninstall: nvidia-curand-cu12 Found existing installation: nvidia-curand-cu12 10.3.6.82 Uninstalling nvidia-curand-cu12-10.3.6.82: Successfully uninstalled nvidia-curand-cu12-10.3.6.82 Attempting uninstall: nvidia-cufft-cu12 Found existing installation: nvidia-cufft-cu12 11.2.3.61 Uninstalling nvidia-cufft-cu12-11.2.3.61: Successfully uninstalled nvidia-cufft-cu12-11.2.3.61 Attempting uninstall: nvidia-cuda-runtime-cu12 Found existing installation: nvidia-cuda-runtime-cu12 12.5.82 Uninstalling nvidia-cuda-runtime-cu12-12.5.82: Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82 Attempting uninstall: nvidia-cuda-nvrtc-cu12 Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82 Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82: Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82 Attempting uninstall: nvidia-cuda-cupti-cu12 Found existing installation: nvidia-cuda-cupti-cu12 12.5.82 Uninstalling nvidia-cuda-cupti-cu12-12.5.82: Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82 Attempting uninstall: nvidia-cublas-cu12 Found existing installation: nvidia-cublas-cu12 12.5.3.2 Uninstalling nvidia-cublas-cu12-12.5.3.2: Successfully uninstalled nvidia-cublas-cu12-12.5.3.2 Attempting uninstall: nvidia-cusparse-cu12 Found existing installation: nvidia-cusparse-cu12 12.5.1.3 Uninstalling nvidia-cusparse-cu12-12.5.1.3: Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3 Attempting uninstall: nvidia-cudnn-cu12 Found existing installation: nvidia-cudnn-cu12 9.3.0.75 Uninstalling nvidia-cudnn-cu12-9.3.0.75: Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75 Attempting uninstall: nvidia-cusolver-cu12 Found existing installation: nvidia-cusolver-cu12 11.6.3.83 Uninstalling nvidia-cusolver-cu12-11.6.3.83: Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83 Successfully installed lpips-0.1.4 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-msssim-1.0.0
2. Importing & Loading the Dataset¶
In [ ]:
auth.authenticate_user() #
drive.mount('/content/drive', force_remount=True)
os.makedirs("/content/ebb_dataset", exist_ok=True)
Mounted at /content/drive
In [ ]:
!wget "http://data.vision.ee.ethz.ch/ihnatova/public/ebb/Bokeh_Simulation_Dataset.zip" -O ebb_dataset/EBB_dataset.zip
!unzip -oq ebb_dataset/EBB_dataset.zip -d ebb_dataset/
base_path = f'/content/ebb_dataset/'
--2025-05-05 20:24:33-- http://data.vision.ee.ethz.ch/ihnatova/public/ebb/Bokeh_Simulation_Dataset.zip Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:36c2::178 Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:80... connected. HTTP request sent, awaiting response... 302 Found Location: https://data.vision.ee.ethz.ch/ihnatova/public/ebb/Bokeh_Simulation_Dataset.zip [following] --2025-05-05 20:24:34-- https://data.vision.ee.ethz.ch/ihnatova/public/ebb/Bokeh_Simulation_Dataset.zip Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 3791786922 (3.5G) [application/zip] Saving to: ‘ebb_dataset/EBB_dataset.zip’ ebb_dataset/EBB_dat 100%[===================>] 3.53G 22.1MB/s in 2m 45s 2025-05-05 20:27:20 (21.9 MB/s) - ‘ebb_dataset/EBB_dataset.zip’ saved [3791786922/3791786922]
In [ ]:
# fixed incorrect naming of test data as the validation set
!mv /content/ebb_dataset/validation /content/ebb_dataset/test
!echo "Renamed 'validation' directory to 'test'."
!ls /content/ebb_dataset/
Renamed 'validation' directory to 'test'. EBB_dataset.zip README.txt test train
In [ ]:
# @title Perform validation split on original training data
BASE_PATH = '/content/ebb_dataset/'
ORIGINAL_TRAIN_FOLDER = os.path.join(BASE_PATH, 'train')
NEW_VALIDATION_FOLDER = os.path.join(BASE_PATH, 'validation') # new folder
VALIDATION_SPLIT_RATIO = 0.20
SPLIT_SEED = 42
print(f"Source Train Folder: {ORIGINAL_TRAIN_FOLDER}")
print(f"Target Validation Folder: {NEW_VALIDATION_FOLDER}")
print(f"Validation Split Ratio: {VALIDATION_SPLIT_RATIO:.0%}")
# checks below:
if not os.path.isdir(ORIGINAL_TRAIN_FOLDER):
print(f"ERROR: Original training folder '{ORIGINAL_TRAIN_FOLDER}' not found. Cannot split.")
elif os.path.isdir(NEW_VALIDATION_FOLDER):
print(f"ERROR: Target validation folder '{NEW_VALIDATION_FOLDER}' already exists.")
print(" Please remove or rename it if you want to re-run the split.")
print(" Skipping file moving operation.")
else:
original_subdir = os.path.join(ORIGINAL_TRAIN_FOLDER, 'original')
bokeh_subdir = os.path.join(ORIGINAL_TRAIN_FOLDER, 'bokeh')
if not os.path.isdir(original_subdir) or not os.path.isdir(bokeh_subdir):
print(f"ERROR: Missing 'original' or 'bokeh' subdirectories inside {ORIGINAL_TRAIN_FOLDER}.")
else:
print("Finding image pairs in original training folder...")
try:
original_files = set(f for f in os.listdir(original_subdir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG')))
bokeh_files = set(f for f in os.listdir(bokeh_subdir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG')))
except FileNotFoundError:
print("Error listing files during pair finding.")
original_files, bokeh_files = set(), set()
image_pair_names = sorted(list(original_files.intersection(bokeh_files)))
num_total_pairs = len(image_pair_names)
if num_total_pairs == 0:
print("ERROR: No matching image pairs found in the original training folder. Cannot split.")
else:
print(f"Found {num_total_pairs} image pairs.")
# calc split sizes
num_val_pairs = int(math.floor(VALIDATION_SPLIT_RATIO * num_total_pairs))
num_train_pairs_new = num_total_pairs - num_val_pairs
print(f"Calculated split: {num_train_pairs_new} pairs for new Train, {num_val_pairs} pairs for new Validation.")
# shuffling and selecting
np.random.seed(SPLIT_SEED)
indices = list(range(num_total_pairs))
np.random.shuffle(indices)
validation_indices = indices[:num_val_pairs]
validation_files_to_move = [image_pair_names[i] for i in validation_indices]
print(f"Selected {len(validation_files_to_move)} files for validation set.")
# new val directories
new_val_original_dir = os.path.join(NEW_VALIDATION_FOLDER, 'original')
new_val_bokeh_dir = os.path.join(NEW_VALIDATION_FOLDER, 'bokeh')
try:
os.makedirs(new_val_original_dir, exist_ok=True)
os.makedirs(new_val_bokeh_dir, exist_ok=True)
print(f"Created new validation directories:")
print(f" - {new_val_original_dir}")
print(f" - {new_val_bokeh_dir}")
except OSError as e:
print(f"ERROR creating validation directories: {e}")
validation_files_to_move = [] # Prevent moving if dirs failed
# move files
if validation_files_to_move:
print(f"\nMoving {len(validation_files_to_move)} pairs to validation folder...")
moved_count = 0
error_count = 0
for filename in tqdm(validation_files_to_move, desc="Moving Files"):
src_orig_path = os.path.join(original_subdir, filename)
src_bokeh_path = os.path.join(bokeh_subdir, filename)
dst_orig_path = os.path.join(new_val_original_dir, filename)
dst_bokeh_path = os.path.join(new_val_bokeh_dir, filename)
try:
if os.path.exists(src_orig_path) and os.path.exists(src_bokeh_path):
shutil.move(src_orig_path, dst_orig_path)
shutil.move(src_bokeh_path, dst_bokeh_path)
moved_count += 1
else:
print(f"Warning: Source file missing, cannot move {filename}")
error_count += 1
except Exception as e:
print(f"ERROR moving file {filename}: {e}")
error_count += 1
print(f"\nFile Moving Complete:")
print(f" Successfully moved {moved_count} pairs.")
if error_count > 0:
print(f" Encountered errors for {error_count} pairs.")
# final checkss
print("\nVerifying final counts...")
final_train_orig_count = len([f for f in os.listdir(original_subdir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG'))])
final_val_orig_count = len([f for f in os.listdir(new_val_original_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG'))])
print(f" Files remaining in '{original_subdir}': {final_train_orig_count} (Expected: {num_train_pairs_new})")
print(f" Files moved to '{new_val_original_dir}': {final_val_orig_count} (Expected: {num_val_pairs})")
else:
print("No files selected for moving (possibly due to directory creation error).")
Source Train Folder: /content/ebb_dataset/train Target Validation Folder: /content/ebb_dataset/validation Validation Split Ratio: 20% Finding image pairs in original training folder... Found 4694 image pairs. Calculated split: 3756 pairs for new Train, 938 pairs for new Validation. Selected 938 files for validation set. Created new validation directories: - /content/ebb_dataset/validation/original - /content/ebb_dataset/validation/bokeh Moving 938 pairs to validation folder...
Moving Files: 100%|██████████| 938/938 [00:00<00:00, 10305.09it/s]
File Moving Complete: Successfully moved 938 pairs. Verifying final counts... Files remaining in '/content/ebb_dataset/train/original': 3756 (Expected: 3756) Files moved to '/content/ebb_dataset/validation/original': 938 (Expected: 938)
3. Models¶
In [ ]:
class BokehDataset(Dataset):
def __init__(self, original_dir, bokeh_dir, transform=None):
self.original_dir = original_dir
self.bokeh_dir = bokeh_dir
self.image_names = sorted(os.listdir(original_dir))
self.transform = transform
def __len__(self):
return len(self.image_names)
def __getitem__(self, idx):
original_path = os.path.join(self.original_dir, self.image_names[idx])
bokeh_path = os.path.join(self.bokeh_dir, self.image_names[idx])
original_image = Image.open(original_path).convert('RGB')
bokeh_image = Image.open(bokeh_path).convert('RGB')
if self.transform:
original_image = self.transform(original_image)
bokeh_image = self.transform(bokeh_image)
return original_image, bokeh_image
In [ ]:
# CNN 3 level construction
class bokeh_CNN(nn.Module):
def __init__(self, in_channels=3, out_channels=3, init_features=32):
super(bokeh_CNN, self).__init__()
features = init_features
# Input channels used here
self.enc1 = nn.Sequential(
nn.Conv2d(in_channels, features, kernel_size=3, padding=1, bias=False), # use in_channels
nn.BatchNorm2d(features), nn.ReLU(True),
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features), nn.ReLU(True)
)
self.pool1 = nn.MaxPool2d(2, 2)
# Encoder lvl2
self.enc2 = nn.Sequential(
nn.Conv2d(features, features * 2, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features * 2), nn.ReLU(True),
nn.Conv2d(features * 2, features * 2, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features * 2), nn.ReLU(True)
)
self.pool2 = nn.MaxPool2d(2, 2)
# Encoder LVL 3
self.enc3 = nn.Sequential(
nn.Conv2d(features * 2, features * 4, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features * 4), nn.ReLU(True),
nn.Conv2d(features * 4, features * 4, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features * 4), nn.ReLU(True)
)
# Decoder levl 2
self.up2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
self.dec2 = nn.Sequential(
# Input channels = upsampled channels + skip connection channels
nn.Conv2d(features * 4, features * 2, kernel_size=3, padding=1, bias=False), # features*2 (from up2) + features*2 (from enc2)
nn.BatchNorm2d(features * 2), nn.ReLU(True),
nn.Conv2d(features * 2, features * 2, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features * 2), nn.ReLU(True)
)
# Decoder level 1
self.up1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
self.dec1 = nn.Sequential(
# input channels = upsampled channels + skip connection channels
nn.Conv2d(features * 2, features, kernel_size=3, padding=1, bias=False), # features (from up1) + features (from enc1)
nn.BatchNorm2d(features), nn.ReLU(True),
nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(features), nn.ReLU(True)
)
# final output conv
self.output_conv = nn.Conv2d(features, out_channels, kernel_size=1) # bias is okay in final layer
def forward(self, x):
# enc path
x1 = self.enc1(x)
p1 = self.pool1(x1)
x2 = self.enc2(p1)
p2 = self.pool2(x2)
x3 = self.enc3(p2) # Bottleneck
# Decoder path with skip connections
u2 = self.up2(x3)
# handle pot. size mismatches (less likely though as using padding=1 and stride=2)
diffY = x2.size()[2] - u2.size()[2]
diffX = x2.size()[3] - u2.size()[3]
u2 = F.pad(u2, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
cat2 = torch.cat([u2, x2], dim=1) # concatenate along channel dimension
x4 = self.dec2(cat2)
u1 = self.up1(x4)
diffY = x1.size()[2] - u1.size()[2]
diffX = x1.size()[3] - u1.size()[3]
u1 = F.pad(u1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
cat1 = torch.cat([u1, x1], dim=1)
x5 = self.dec1(cat1)
out = self.output_conv(x5)
return torch.sigmoid(out)
model_instance = bokeh_CNN(in_channels=4, out_channels=3, init_features=32)
total_params = sum(p.numel() for p in model_instance.parameters())
trainable_params = sum(p.numel() for p in model_instance.parameters() if p.requires_grad)
print(f"Input Channels: {model_instance.enc1[0].in_channels}")
print(f"Output Channels: {model_instance.output_conv.out_channels}")
print(f"Initial Features: {model_instance.enc1[0].out_channels}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("-----------------------------------------------------\n")
Input Channels: 4 Output Channels: 3 Initial Features: 32 Total parameters: 467,523 Trainable parameters: 467,523 -----------------------------------------------------
4. Multi weighted Loss function - inspired by PyNet except all combined¶
In [ ]:
def gaussian_pyramid(img: torch.Tensor, max_levels: int = 3) -> list:
# separable 5-tap Gaussian kernel
k1d = torch.tensor([1.,4.,6.,4.,1.], device=img.device) / 16.
ker = (k1d[:,None] * k1d[None,:]).unsqueeze(0).unsqueeze(0)
ker = ker.repeat(img.size(1),1,1,1)
pyr = [img]
cur = img
for _ in range(max_levels):
blurred = F.conv2d(cur, ker, padding=2, groups=cur.size(1))
cur = blurred[:, :, ::2, ::2]
pyr.append(cur)
return pyr
class WeightedPyNetPerceptualLoss(nn.Module):
def __init__(self,
fg_weight: float = 2.0,
bg_weight: float = 1.0,
max_levels: int = 3,
level_weights: list = None,
lpips_weight: float = 1.0,
ssim_weight: float = 1.0):
super().__init__()
self.fg_w, self.bg_w = fg_weight, bg_weight
self.max_levels = max_levels
self.level_weights = level_weights or [1.0/(2**i) for i in range(max_levels+1)]
self.lpips_fn = lpips.LPIPS(net='alex').eval()
self.lpips_weight = lpips_weight
self.ssim_weight = ssim_weight
# laplacian kernel for focus mask
lap = torch.tensor([[0.,1.,0.],[1.,-4.,1.],[0.,1.,0.]])
self.register_buffer('lap_kernel', lap.unsqueeze(0).unsqueeze(0))
def generate_focus_mask(self, img: torch.Tensor, threshold: float = 0.03) -> torch.Tensor:
gray = img.mean(dim=1, keepdim=True)
lap = F.conv2d(gray, self.lap_kernel, padding=1)
return (lap.abs() > threshold).float()
def forward(self,
pred: torch.Tensor,
target: torch.Tensor,
original: torch.Tensor) -> (torch.Tensor, dict):
mask_full = self.generate_focus_mask(original).to(pred.device)
pred_pyr = gaussian_pyramid(pred, self.max_levels)
tgt_pyr = gaussian_pyramid(target, self.max_levels)
l1_loss, ssim_loss = 0.0, 0.0
for lvl, (p, t) in enumerate(zip(pred_pyr, tgt_pyr)):
w_lvl = self.level_weights[lvl]
mask_lvl = F.interpolate(mask_full, size=p.shape[2:], mode='nearest')
pixel_w = mask_lvl * self.fg_w + (1 - mask_lvl) * self.bg_w
l1 = (p - t).abs()
l1_loss += w_lvl * (l1 * pixel_w).mean()
ssim_val = ssim(p, t, data_range=1.0)
avg_w = pixel_w.mean()
ssim_loss += w_lvl * (1 - ssim_val) * avg_w
lpips_loss = self.lpips_fn(pred, target).mean()
total = l1_loss + self.ssim_weight * ssim_loss + self.lpips_weight * lpips_loss
return total, {
'l1_multi': l1_loss,
'ssim_multi': ssim_loss,
'lpips': lpips_loss
}
5. Mini training loops to evaluate different architectures and hyperparameter combinations (before final loop)¶
In [ ]:
# @title Vanilla U-Net experimentation, tweak hyperparameters for Table 1 & 2 Results
torch.manual_seed(42) # for reproducibility
np.random.seed(42)
# paths
BASE_PATH = '/content/ebb_dataset/'
DRIVE_SAVE_DIR = "/content/drive/MyDrive/Bokeh_MiniLoop_Runs"
# Mini-Loop DATA SUBSETS - for speed and due to compute constraints
TRAIN_SUBSET_TARGET_SIZE = 1024
VAL_SUBSET_TARGET_SIZE = 180
IMG_SIZE = 512
EVAL_CROP_BORDER = 32
EVAL_SSIM_WIN_SIZE = 11
GRAD_CLIP_VALUE = 1.0
os.makedirs(DRIVE_SAVE_DIR, exist_ok=True)
print(f"Mini-loop models/checkpoints will be saved in: {DRIVE_SAVE_DIR}")
#setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Running on device: {device} ---")
# Single Transform, no augm.
transform_no_aug = Compose([Resize((IMG_SIZE, IMG_SIZE)), ToTensor()])
print("Defined unified transform (No Augmentation).")
# dataset & loaders
train_loader_mini = None; val_loader_mini = None; test_loader = None
try:
if 'BASE_PATH' not in globals() or not os.path.isdir(BASE_PATH): raise ValueError("'BASE_PATH' not defined")
original_train_dir=os.path.join(BASE_PATH, 'train/original'); bokeh_train_dir=os.path.join(BASE_PATH, 'train/bokeh'); new_val_original_dir=os.path.join(BASE_PATH, 'validation/original'); new_val_bokeh_dir=os.path.join(BASE_PATH, 'validation/bokeh'); test_original_dir=os.path.join(BASE_PATH, 'test/original'); test_bokeh_dir=os.path.join(BASE_PATH, 'test/bokeh')
print("Loading full original training dataset index for splitting..."); full_new_train_dataset = BokehDataset(original_dir=original_train_dir, bokeh_dir=bokeh_train_dir, transform=transform_no_aug); num_actual_train_samples = len(full_new_train_dataset); print(f"Found {num_actual_train_samples} pairs in 'train'."); assert num_actual_train_samples > 0
print("Loading index for the new 'validation' dataset..."); full_new_val_dataset = BokehDataset(original_dir=new_val_original_dir, bokeh_dir=new_val_bokeh_dir, transform=transform_no_aug); num_actual_val_samples = len(full_new_val_dataset); print(f"Found {num_actual_val_samples} pairs in 'validation'."); assert num_actual_val_samples > 0
train_subset_size = min(TRAIN_SUBSET_TARGET_SIZE, num_actual_train_samples); val_subset_size = min(VAL_SUBSET_TARGET_SIZE, num_actual_val_samples)
train_subset_loop, _ = random_split(full_new_train_dataset, [train_subset_size, num_actual_train_samples - train_subset_size], generator=torch.Generator().manual_seed(42))
val_subset_loop, _ = random_split(full_new_val_dataset, [val_subset_size, num_actual_val_samples - val_subset_size], generator=torch.Generator().manual_seed(42))
print(f"\nMini-loop random subsets created: {len(train_subset_loop)} train / {len(val_subset_loop)} validation.")
def collate_fn_skip_none(batch): batch=list(filter(lambda x: x is not None and x[0] is not None and x[1] is not None, batch)); return torch.utils.data.dataloader.default_collate(batch) if batch else None
num_workers = 2 # BATCH_SIZE set in EXP block
# Loaders configured in EXP block
if os.path.isdir(test_original_dir) and os.path.isdir(test_bokeh_dir):
test_dataset = BokehDataset(original_dir=test_original_dir, bokeh_dir=test_bokeh_dir, transform=transform_no_aug)
if len(test_dataset) > 0: print(f"Test dataset loaded ({len(test_dataset)} samples).")
else: print("Test dataset directory found but is empty.")
else: print("Test dataset directory not found.")
except Exception as e: print(f"Error setting up Datasets/DataLoaders: {e}"); raise
# evaluation function definition ->
def evaluate_model(model, dataloader, lpips_eval_fn, crop_border=0, ssim_win_size=11):
model.eval(); eval_device = next(model.parameters()).device; total_psnr, total_ssim_sk, total_msssim, total_lpips = 0.0, 0.0, 0.0, 0.0; lpips_msssim_count, psnr_ssim_count = 0, 0
if dataloader is None: return 0,0,0,0
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating", leave=False, disable=len(dataloader)<5):
if batch is None: continue
try: inputs, targets = batch; assert isinstance(inputs, torch.Tensor) and isinstance(targets, torch.Tensor)
except (AssertionError, TypeError, ValueError): continue
if inputs.shape[1] != 3 or targets.shape[1] != 3: continue
inputs, targets = inputs.to(eval_device), targets.to(eval_device); outputs = model(inputs)
outputs_clipped = torch.clamp(outputs, 0.0, 1.0); current_batch_size = outputs_clipped.shape[0]; outputs_lpips_input = outputs_clipped * 2.0 - 1.0; targets_lpips_input = targets * 2.0 - 1.0
try: lpips_val = lpips_eval_fn(outputs_lpips_input, targets_lpips_input.detach()); total_lpips += lpips_val.sum().item(); lpips_msssim_count += current_batch_size
except Exception: pass
try: ms_ssim_val_batch = ms_ssim(outputs_clipped, targets.detach(), data_range=1.0, size_average=False); total_msssim += ms_ssim_val_batch.sum().item()
except Exception: pass
outputs_np_full = outputs_clipped.cpu().numpy(); targets_np_full = targets.cpu().numpy()
for i in range(current_batch_size):
out_img_full_np = np.transpose(outputs_np_full[i], (1, 2, 0)); tgt_img_full_np = np.transpose(targets_np_full[i], (1, 2, 0)); h, w = out_img_full_np.shape[:2]
if crop_border > 0 and h > 2 * crop_border and w > 2 * crop_border: out_img_eval_np=out_img_full_np[crop_border:-crop_border, crop_border:-crop_border, :]; tgt_img_eval_np=tgt_img_full_np[crop_border:-crop_border, crop_border:-crop_border, :]
else: out_img_eval_np=out_img_full_np; tgt_img_eval_np=tgt_img_full_np
try:
psnr=psnr_metric_skimage(tgt_img_eval_np, out_img_eval_np, data_range=1.0); ch_sk, cw_sk = tgt_img_eval_np.shape[:2]; current_win_size=min(ssim_win_size, ch_sk, cw_sk)
if current_win_size % 2 == 0: current_win_size -= 1; current_win_size=max(3, current_win_size)
if ch_sk >= current_win_size and cw_sk >= current_win_size : ssim_val_sk=ssim_metric_skimage(tgt_img_eval_np, out_img_eval_np, channel_axis=-1, data_range=1.0, win_size=current_win_size, gaussian_weights=True, multichannel=True)
else: ssim_val_sk = np.nan
if not np.isnan(psnr) and not np.isnan(ssim_val_sk): total_psnr+=psnr; total_ssim_sk+=ssim_val_sk; psnr_ssim_count+=1
except ValueError: pass
avg_psnr=total_psnr/psnr_ssim_count if psnr_ssim_count>0 else 0; avg_ssim_sk=total_ssim_sk/psnr_ssim_count if psnr_ssim_count > 0 else 0; avg_msssim=total_msssim/lpips_msssim_count if lpips_msssim_count > 0 else 0; avg_lpips=total_lpips/lpips_msssim_count if lpips_msssim_count > 0 else 0
return avg_psnr, avg_ssim_sk, avg_msssim, avg_lpips
# training loop function
def run_mini_training_experiment(model_name, model, train_loader, val_loader, criterion, lpips_eval_fn, optimizer, epochs, device):
print(f"--- Starting Mini-Loop Training: {model_name} ---"); model.to(device)
history = {'train_loss':[], 'val_loss':[], 'val_l1_loss':[], 'val_psnr':[], 'val_ssim_sk':[], 'val_msssim':[], 'val_lpips':[]}
start_run_time = time.time()
if train_loader is None or len(train_loader)==0: print("ERROR: Train loader empty."); return history, *([float('nan')]*7), 0.0
for epoch in range(epochs):
epoch_start_time = time.time(); model.train(); running_train_loss = 0.0
pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} [MiniTrain]", leave=False)
for i, batch in pbar:
if batch is None: continue
try: inputs, targets = batch; assert isinstance(inputs, torch.Tensor) and isinstance(targets, torch.Tensor)
except (AssertionError, TypeError, ValueError): continue
if inputs.shape[1] != 3 or targets.shape[1] != 3: continue
try:
inputs, targets = inputs.to(device), targets.to(device); optimizer.zero_grad(); outputs = model(inputs); total_loss, _ = criterion(pred=outputs, target=targets, original=inputs)
if torch.isnan(total_loss) or torch.isinf(total_loss): continue
total_loss.backward();
if 'GRAD_CLIP_VALUE' in globals() and GRAD_CLIP_VALUE: torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
optimizer.step(); running_train_loss += total_loss.item(); pbar.set_postfix({'Loss': f'{total_loss.item():.4f}'})
except Exception as batch_err: print(f"Error processing train batch {i}: {batch_err}"); traceback.print_exc(); continue
avg_train_loss = running_train_loss / len(train_loader) if len(train_loader) > 0 else 0; history['train_loss'].append(avg_train_loss)
model.eval(); running_val_loss = 0.0; running_val_l1_loss = 0.0; avg_val_l1_loss = float('nan'); avg_psnr, avg_ssim_sk, avg_msssim, avg_lpips = 0,0,0,0
if val_loader is None or len(val_loader) == 0: print("Warning: Val loader empty."); avg_val_loss = float('nan')
else:
avg_psnr, avg_ssim_sk, avg_msssim, avg_lpips = evaluate_model(model, val_loader, lpips_eval_fn, EVAL_CROP_BORDER, EVAL_SSIM_WIN_SIZE)
with torch.no_grad():
for batch_val in tqdm(val_loader, desc=f"Epoch {epoch+1} [Validation]", leave=False): # Add tqdm
if batch_val is None: continue
try: inputs_val, targets_val = batch_val; assert isinstance(inputs_val, torch.Tensor) and isinstance(targets_val, torch.Tensor)
except (AssertionError, TypeError, ValueError): continue
if inputs_val.shape[1] != 3 or targets_val.shape[1] != 3: continue
try:
inputs_val, targets_val = inputs_val.to(device), targets_val.to(device); outputs_val = model(inputs_val)
val_total_loss, val_loss_components = criterion(pred=outputs_val, target=targets_val, original=inputs_val)
if torch.isnan(val_total_loss) or torch.isinf(val_total_loss): continue
running_val_loss += val_total_loss.item(); running_val_l1_loss += val_loss_components.get('l1_multi_weighted', 0.0)
except Exception as val_batch_err: print(f"Error processing validation batch: {val_batch_err}"); continue
avg_val_loss = running_val_loss / len(val_loader) if len(val_loader) > 0 else float('inf'); avg_val_l1_loss = running_val_l1_loss / len(val_loader) if len(val_loader) > 0 else float('nan')
history['val_loss'].append(avg_val_loss); history['val_l1_loss'].append(avg_val_l1_loss); history['val_psnr'].append(avg_psnr); history['val_ssim_sk'].append(avg_ssim_sk); history['val_msssim'].append(avg_msssim); history['val_lpips'].append(avg_lpips)
epoch_duration = time.time() - epoch_start_time
print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f} - Val L1: {avg_val_l1_loss:.4f} | PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim_sk:.4f}, MS-SSIM: {avg_msssim:.4f}, LPIPS: {avg_lpips:.4f} - Time: {epoch_duration:.2f}s")
run_duration_mins = (time.time() - start_run_time) / 60.0; print(f"--- Finished Mini-Loop Training: {model_name} ({run_duration_mins:.2f} mins) ---")
final_train_loss = history['train_loss'][-1] if history.get('train_loss') else float('nan'); final_val_loss = history['val_loss'][-1] if history.get('val_loss') else float('nan'); final_l1_loss = history['val_l1_loss'][-1] if history.get('val_l1_loss') else float('nan'); final_psnr = history['val_psnr'][-1] if history.get('val_psnr') else float('nan'); final_ssim_sk = history['val_ssim_sk'][-1] if history.get('val_ssim_sk') else float('nan'); final_msssim = history['val_msssim'][-1] if history.get('val_msssim') else float('nan'); final_lpips = history['val_lpips'][-1] if history.get('val_lpips') else float('nan')
return history, final_train_loss, final_val_loss, final_l1_loss, final_psnr, final_ssim_sk, final_msssim, final_lpips, run_duration_mins
# --- MAIN MINI-LOOP EXPERIMENT EXECUTION BLOCK ---
EXP_ID = "Mini_VUNet_IF32_LW_S4.0_P0.5_B4_E10" # NEW ID
EXP_SSIM_WEIGHT = 4.0
EXP_LPIPS_WEIGHT = 0.5
EXP_BATCH_SIZE = 4
EXP_NUM_EPOCHS = 10
EXP_LEARNING_RATE = 2e-4
EXP_MODEL_CLASS = bokeh_CNN
EXP_MODEL_NAME = "Vanilla_UNet (bokeh_CNN)"
EXP_INIT_FEATURES = 32
EXP_INPUT_CHANNELS = 3
EXP_DEPTH_LEVELS = 3
MODEL_KWARGS = {}
# ---
print(f"\n\n{'='*30} STARTING MINI-LOOP EXPERIMENT: {EXP_ID} {'='*30}")
# instantiate
print(f"Instantiating model: {EXP_MODEL_NAME}..."); model_instance = None; params_m = float('nan')
try:
# N.B. with BatchNorm
model_instance = EXP_MODEL_CLASS(
in_channels=EXP_INPUT_CHANNELS,
out_channels=3,
init_features=EXP_INIT_FEATURES,
**MODEL_KWARGS
).to(device)
total_params = sum(p.numel() for p in model_instance.parameters()); params_m = total_params / 1_000_000
print(f"Model Instantiated: {params_m:.2f}M parameters")
except Exception as e: print(f"Model Instantiation Error: {e}"); raise
# config MINI DataLoaders
current_train_loader = None; current_val_loader = None
try:
if 'train_subset_loop' not in globals() or 'val_subset_loop' not in globals(): raise NameError("Mini-loop subset datasets not defined.")
current_train_loader = DataLoader(train_subset_loop, batch_size=EXP_BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn_skip_none)
current_val_loader = DataLoader(val_subset_loop, batch_size=EXP_BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn_skip_none)
print(f"Mini-loop DataLoaders configured with Batch Size: {EXP_BATCH_SIZE}"); assert len(current_train_loader)>0 and len(current_val_loader)>0
except Exception as e: print(f"Error configuring Mini-loop DataLoaders: {e}"); raise
# loss func.
criterion = None; lpips_eval_fn = None
try:
criterion = WeightedPyNetPerceptualLoss(lpips_weight=EXP_LPIPS_WEIGHT, ssim_weight=EXP_SSIM_WEIGHT).to(device)
lpips_eval_fn = lpips.LPIPS(net='alex', verbose=False).to(device)
print(f"Loss function instantiated with SSIM_w={EXP_SSIM_WEIGHT}, LPIPS_w={EXP_LPIPS_WEIGHT}")
except Exception as e: print(f"Loss Instantiation Error: {e}"); raise
print("\n--- Mini-Loop Experiment Configuration ---"); print(f"Exp ID: {EXP_ID}"); print(f"Model Arch: {EXP_MODEL_NAME}"); print(f"Input Channels: {EXP_INPUT_CHANNELS} (RGB)"); print(f"Width (IF): {EXP_INIT_FEATURES}"); print(f"Depth (Levels): {EXP_DEPTH_LEVELS}"); print(f"Learning Rate: {EXP_LEARNING_RATE}"); print(f"Batch Size: {EXP_BATCH_SIZE}"); print(f"Target Epochs: {EXP_NUM_EPOCHS}"); print(f"SSIM Weight: {EXP_SSIM_WEIGHT}"); print(f"LPIPS Weight: {EXP_LPIPS_WEIGHT}"); print(f"Parameters (M): {params_m:.2f}M"); print("--------------------------------------\n")
# optimiser ->
optimizer = optim.Adam(model_instance.parameters(), lr=EXP_LEARNING_RATE)
# train mini loop
history = {}; final_train_loss, final_val_loss, final_l1_loss = [float('nan')] * 3
final_psnr, final_ssim_sk, final_msssim, final_lpips = [float('nan')] * 4; run_duration_mins = float('nan')
try:
history, final_train_loss, final_val_loss, final_l1_loss, \
final_psnr, final_ssim_sk, final_msssim, final_lpips, \
run_duration_mins = run_mini_training_experiment(
model_name=f"MiniLoop_{EXP_ID}", model=model_instance, train_loader=current_train_loader, val_loader=current_val_loader,
criterion=criterion, lpips_eval_fn=lpips_eval_fn, optimizer=optimizer, epochs=EXP_NUM_EPOCHS, device=device
)
except Exception as main_loop_err: print(f"\n!!! ERROR during training execution for {EXP_ID}: {main_loop_err} !!!"); traceback.print_exc()
# save model
if not math.isnan(final_val_loss):
save_path = os.path.join(DRIVE_SAVE_DIR, f"{EXP_ID}_model.pth")
try: torch.save(model_instance.state_dict(), save_path); print(f"Mini-loop model weights saved to: {save_path}")
except Exception as e: print(f"Error saving mini-loop model: {e}")
# results
print("\n--- Mini-Loop Experiment Results Summary ---")
print("| Exp ID | DataAug | Width | Depth | Lr | B_s | Epochs | Params (M) | Runtime (Min) | Train_Loss | Val_Loss | Gap |")
print("|---------------------------------------------------|---------|-------|-------|--------|-----|--------|------------|---------------|------------|----------|--------|")
exp_id_str = f"{EXP_ID:<49}"; data_aug_str = "N"; width_str = f"{EXP_INIT_FEATURES:<5}"; depth_str = f"{EXP_DEPTH_LEVELS:<5}"
lr_str = f"{EXP_LEARNING_RATE:<6.1e}"; bs_str = f"{EXP_BATCH_SIZE:<3}"; epoch_str = f"{EXP_NUM_EPOCHS:<6}"
param_str = f"{params_m:<10.2f}"; runtime_str = f"{run_duration_mins:<13.2f}" if not math.isnan(run_duration_mins) else "N/A"
train_loss_str = f"{final_train_loss:<10.4f}" if not math.isnan(final_train_loss) else "N/A"; val_loss_str = f"{final_val_loss:<8.4f}" if not math.isnan(final_val_loss) else "N/A"
gap_val = float('nan');
if not math.isnan(final_val_loss) and not math.isnan(final_train_loss): gap_val = final_val_loss - final_train_loss
gap_str = f"{gap_val:<6.4f}" if not math.isnan(gap_val) else "N/A"
print(f"| {exp_id_str} | {data_aug_str:<7} | {width_str} | {depth_str} | {lr_str} | {bs_str} | {epoch_str} | {param_str}M | {runtime_str} | {train_loss_str} | {val_loss_str} | {gap_str} |")
print("-" * 145)
print(f"(Final Val Metrics: PSNR={final_psnr:.2f}, SSIM={final_ssim_sk:.4f}, MS-SSIM={final_msssim:.4f}, LPIPS={final_lpips:.4f}, Val L1={final_l1_loss:.4f})")
print(f"\n{'='*30} FINISHED MINI-LOOP EXPERIMENT: {EXP_ID} {'='*30}\n\n")
In [ ]:
# @title QUALITATIVE ANALYSIS OF LOSS WEIGHTS (Figure 1)
# reusing bokeh_CNN class definition here with batchnorm
class bokeh_CNN(nn.Module):
def __init__(self, in_channels=3, out_channels=3, init_features=32):
super(bokeh_CNN, self).__init__(); features = init_features
self.enc1=nn.Sequential(nn.Conv2d(in_channels, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True), nn.Conv2d(features, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True)); self.pool1=nn.MaxPool2d(2, 2)
self.enc2=nn.Sequential(nn.Conv2d(features, features * 2, 3, padding=1, bias=False), nn.BatchNorm2d(features * 2), nn.ReLU(True), nn.Conv2d(features * 2, features * 2, 3, padding=1, bias=False), nn.BatchNorm2d(features * 2), nn.ReLU(True)); self.pool2=nn.MaxPool2d(2, 2)
self.enc3=nn.Sequential(nn.Conv2d(features * 2, features * 4, 3, padding=1, bias=False), nn.BatchNorm2d(features * 4), nn.ReLU(True), nn.Conv2d(features * 4, features * 4, 3, padding=1, bias=False), nn.BatchNorm2d(features * 4), nn.ReLU(True)); self.up2=nn.ConvTranspose2d(features * 4, features * 2, 2, stride=2)
self.dec2=nn.Sequential(nn.Conv2d(features * 4, features * 2, 3, padding=1, bias=False), nn.BatchNorm2d(features * 2), nn.ReLU(True), nn.Conv2d(features * 2, features * 2, 3, padding=1, bias=False), nn.BatchNorm2d(features * 2), nn.ReLU(True)); self.up1=nn.ConvTranspose2d(features * 2, features, 2, stride=2)
self.dec1=nn.Sequential(nn.Conv2d(features * 2, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True), nn.Conv2d(features, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True)); self.output_conv=nn.Conv2d(features, out_channels, 1)
def forward(self, x):
try:
x1=self.enc1(x); p1=self.pool1(x1); x2=self.enc2(p1); p2=self.pool2(x2); x3=self.enc3(p2)
u2=self.up2(x3); diffY=x2.size(2)-u2.size(2); diffX=x2.size(3)-u2.size(3); u2=F.pad(u2,[diffX//2,diffX-diffX//2,diffY//2,diffY-diffY//2]); cat2=torch.cat([u2,x2],dim=1)
x4=self.dec2(cat2); u1=self.up1(x4); diffY=x1.size(2)-u1.size(2); diffX=x1.size(3)-u1.size(3); u1=F.pad(u1,[diffX//2,diffX-diffX//2,diffY//2,diffY-diffY//2]); cat1=torch.cat([u1,x1],dim=1)
x5=self.dec1(cat1); out=self.output_conv(x5); return torch.sigmoid(out)
except Exception as e: print(f"Error in bokeh_CNN forward pass: {e}"); traceback.print_exc(); raise
BASE_PATH = '/content/ebb_dataset/'
SAVE_DIR = "/content/drive/MyDrive/Bokeh_MiniLoop_Runs"
MODEL_CONFIG = {'in_channels': 3, 'out_channels': 3, 'init_features': 32}
IMG_SIZE = 512
# my chosen files with 'challenging' scenes which i wanted to do qualitative analysis on
SELECTED_FILENAMES = [
"1724.jpg",
"2086.jpg",
"2013.jpg",
"3591.jpg",
"3588.jpg",
]
MODELS_TO_LOAD = [
("Baseline", 1.0, 0.5, os.path.join(SAVE_DIR, "/content/drive/MyDrive/Bokeh_MiniLoop_Runs/Mini_VUNet_IF128_LW_S1.0_P0.5_B4_E10_model.pth")), # <<< FILENAME from Run 1
("LPIPS Weighted", 1.0, 1.0, os.path.join(SAVE_DIR, "Mini_VUNet_IF32_LW_S1.0_P1.0_B4_E10_model.pth")), # UPDATE FILENAME AS NEEDED
("LPIPS Heavy", 1.0, 2.0, os.path.join(SAVE_DIR, "Mini_VUNet_IF32_LW_S1.0_P2.0_B4_E10_model.pth")), # UPDATE FILENAME AS NEEDED
("SSIM Heavy", 4.0, 0.5, os.path.join(SAVE_DIR, "Mini_VUNet_IF32_LW_S4.0_P0.5_B4_E10_model.pth")), # UPDATE FILENAME AS NEEDED
]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
to_pil = ToPILImage()
transform_model_input = Compose([Resize((IMG_SIZE, IMG_SIZE)), ToTensor()])
resize_display = Resize((IMG_SIZE, IMG_SIZE))
print(f"\n--- Collecting data for {len(SELECTED_FILENAMES)} images and {len(MODELS_TO_LOAD)} models ---")
collected_data = []
comparison_original_dir = os.path.join(BASE_PATH, 'validation/original')
comparison_bokeh_dir = os.path.join(BASE_PATH, 'validation/bokeh')
print(f"Loading images for comparison from: {comparison_original_dir}")
# loading models
model_instances = {}
for label_prefix, s_w, p_w, model_path in MODELS_TO_LOAD:
label = f"{label_prefix} ($\\lambda_1$={s_w:.1f}, $\\lambda_2$={p_w:.1f})" # Lambda 1 = SSIM, Lambda 2 = LPIPS
print(f"Loading model weights for: {label}")
if not os.path.exists(model_path):
print(f" WARNING - Model file not found: {model_path}")
model_instances[label] = None; continue
try:
model = bokeh_CNN(**MODEL_CONFIG).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval(); model_instances[label] = model
except Exception as load_err: print(f" ERROR loading state dict for {label}: {load_err}"); model_instances[label] = None
# loop through th selected images from the validation set ---
for filename in tqdm(SELECTED_FILENAMES, desc="Processing Images"):
original_path = os.path.join(comparison_original_dir, filename)
target_path = os.path.join(comparison_bokeh_dir, filename)
image_data = {'filename': filename, 'input_pil': None, 'target_pil': None, 'outputs': {}}
if not os.path.exists(original_path) or not os.path.exists(target_path): print(f"Skipping {filename} - File not found."); collected_data.append(image_data); continue
try:
original_pil = Image.open(original_path).convert('RGB'); target_pil = Image.open(target_path).convert('RGB')
image_data['input_pil'] = resize_display(original_pil); image_data['target_pil'] = resize_display(target_pil) # store resized for display
input_tensor = transform_model_input(original_pil).unsqueeze(0).to(device)
for label, model in model_instances.items():
if model is None: image_data['outputs'][label] = None; continue
try:
with torch.no_grad(): output_tensor = model(input_tensor)
output_tensor_clamped = torch.clamp(output_tensor.squeeze(0).cpu(), 0.0, 1.0); output_pil = to_pil(output_tensor_clamped)
image_data['outputs'][label] = output_pil
except Exception as inference_err: print(f"ERROR during inference for {filename} with model {label}: {inference_err}"); image_data['outputs'][label] = None
collected_data.append(image_data)
except Exception as e: print(f"ERROR processing file {filename}: {e}"); traceback.print_exc(); collected_data.append(image_data) # Add placeholder on error
print("\n--- Data Collection Complete ---")
# plotting qualitiative plot
print("\n--- Generating Consolidated Plot ---")
num_images = len(collected_data); num_models = len(MODELS_TO_LOAD); num_cols = 2 + num_models
if num_images == 0: print("No data collected to plot.")
else:
base_width_per_plot = 4; fig_width = base_width_per_plot * num_cols; fig_height = base_width_per_plot * num_images
fig, axs = plt.subplots(num_images, num_cols, figsize=(fig_width, fig_height), squeeze=False)
for row_idx, data in enumerate(collected_data):
ax_in = axs[row_idx, 0]; ax_tgt = axs[row_idx, 1]
if data['input_pil']: ax_in.imshow(data['input_pil'])
else: ax_in.text(0.5, 0.5, 'Input Error', ha='center', va='center', color='red', transform=ax_in.transAxes)
if row_idx == 0: ax_in.set_title("Input", fontsize=14)
if data['target_pil']: ax_tgt.imshow(data['target_pil'])
else: ax_tgt.text(0.5, 0.5, 'Target Error', ha='center', va='center', color='red', transform=ax_tgt.transAxes)
if row_idx == 0: ax_tgt.set_title("Target Bokeh", fontsize=14)
col_idx = 2
for label_prefix, s_w, p_w, _ in MODELS_TO_LOAD:
label = f"{label_prefix} ($\\lambda_1$={s_w:.1f}, $\\lambda_2$={p_w:.1f})" # Lambda1=SSIM, Lambda2=LPIPS
current_ax = axs[row_idx, col_idx]
output_img = data['outputs'].get(label, None)
if output_img: current_ax.imshow(output_img)
else: current_ax.text(0.5, 0.5, 'Error/Missing', ha='center', va='center', fontsize=10, color='red', transform=current_ax.transAxes)
if row_idx == 0: title = f"{label_prefix}\n($\\lambda_1$={s_w:.1f}, $\\lambda_2$={p_w:.1f})"; current_ax.set_title(title, fontsize=14)
col_idx += 1
for ax in axs.flat: ax.set_aspect('equal', adjustable='box'); ax.axis('off')
plt.subplots_adjust(wspace=0.05, hspace=0.05); print("Displaying plot..."); plt.show()
print("\n--- Qualitative Comparison Plot Generation Complete ---")
--- Collecting data for 5 images and 4 models --- Loading images for comparison from: /content/ebb_dataset/validation/original Loading model weights for: Baseline ($\lambda_1$=1.0, $\lambda_2$=0.5) Loading model weights for: LPIPS Weighted ($\lambda_1$=1.0, $\lambda_2$=1.0) Loading model weights for: LPIPS Heavy ($\lambda_1$=1.0, $\lambda_2$=2.0) Loading model weights for: SSIM Heavy ($\lambda_1$=4.0, $\lambda_2$=0.5)
Processing Images: 100%|██████████| 5/5 [00:01<00:00, 4.60it/s]
--- Data Collection Complete --- --- Generating Consolidated Plot --- Displaying plot...
--- Qualitative Comparison Plot Generation Complete ---
6. Training loop¶
In [ ]:
# @title FINAL TRAINING LOOP VANILLA U-NET (BASED ON VBASE_3) - RESULTS SECTION
torch.manual_seed(42) # for reproducibility
np.random.seed(42)
BASE_PATH = '/content/ebb_dataset/'
DRIVE_SAVE_DIR = "/content/drive/MyDrive/Bokeh_Training_Runs"
RUN_ID = "VUNetDeep_IF64_B4_E50_ES6_LRsch3_Sub2048_Final"
MODEL_HP = {
'in_channels': 3,
'out_channels': 3,
'init_features': 64
}
DEPTH_LEVELS = 4 # deeper model for final loop due to better performance
#hyperparam.'s
TRAINING_HP = {
'learning_rate': 2e-4, # starting LR (aggressive )
'batch_size': 4,
'target_epochs': 30,
'grad_clip': 1.0,
}
TRAIN_SUBSET_SIZE = 2048 # for efficiencyy
EARLY_STOPPING_PATIENCE = 6
CHECKPOINT_FREQ = 5 # saving every 5 epochs to drive just incase runtime disconn.
# LR Scheduler Config
SCHEDULER_PATIENCE = 3 # <<< Patience = 3 for LR reduction
SCHEDULER_FACTOR = 0.2 # reduces by 20%
# Image/Evaluation Settings
IMG_SIZE = 512
EVAL_CROP_BORDER = 32
EVAL_SSIM_WIN_SIZE = 11
# Ensure Save Directory Exists
CHECKPOINT_PATH = os.path.join(DRIVE_SAVE_DIR, f"{RUN_ID}_checkpoint.pth")
BEST_MODEL_PATH = os.path.join(DRIVE_SAVE_DIR, f"{RUN_ID}_best_model.pth")
os.makedirs(DRIVE_SAVE_DIR, exist_ok=True)
print(f"Models/Checkpoints will be saved in: {DRIVE_SAVE_DIR}")
### Class definitions, some repeated for individual cell runs ###
class BokehDataset(Dataset):
def __init__(self, original_dir, bokeh_dir, transform=None):
self.original_dir=original_dir; self.bokeh_dir=bokeh_dir; self.transform=transform
if not os.path.isdir(self.original_dir): raise FileNotFoundError(f"Dir not found: {self.original_dir}")
if not os.path.isdir(self.bokeh_dir): raise FileNotFoundError(f"Dir not found: {self.bokeh_dir}")
try:
original_files=set(f for f in os.listdir(original_dir) if f.lower().endswith(('.png','.jpg','.jpeg','.JPG')))
bokeh_files=set(f for f in os.listdir(bokeh_dir) if f.lower().endswith(('.png','.jpg','.jpeg','.JPG')))
except FileNotFoundError: print(f"Warning: Error listing files."); original_files, bokeh_files = set(), set()
self.image_names=sorted(list(original_files.intersection(bokeh_files)))
if not self.image_names: print(f"Warning: No matching image pairs found in {original_dir} and {bokeh_dir}")
def __len__(self): return len(self.image_names)
def __getitem__(self, idx):
if idx >= len(self.image_names): raise IndexError("Index out of bounds")
img_name=self.image_names[idx]; original_path=os.path.join(self.original_dir, img_name); bokeh_path=os.path.join(self.bokeh_dir, img_name)
try:
if not os.path.exists(original_path) or not os.path.exists(bokeh_path): return None, None
with Image.open(original_path) as img_o, Image.open(bokeh_path) as img_b: original_image=img_o.convert('RGB'); bokeh_image=img_b.convert('RGB')
if self.transform: input_tensor=self.transform(original_image); target_tensor=self.transform(bokeh_image)
else: return original_image, bokeh_image
if not isinstance(input_tensor, torch.Tensor) or not isinstance(target_tensor, torch.Tensor): return None, None
return input_tensor, target_tensor
except Exception as e:
return None, None
class WeightedPyNetPerceptualLoss(nn.Module):
def __init__(self, fg_weight: float = 2.0, bg_weight: float = 1.0, max_levels: int = 3, level_weights: list = None, lpips_weight: float = 1.0, ssim_weight: float = 1.0):
super().__init__(); self.fg_w, self.bg_w = fg_weight, bg_weight; self.max_levels = max_levels
self.level_weights = level_weights or [1.0/(2**i) for i in range(max_levels+1)]
if 'lpips' not in globals(): raise NameError("lpips library not imported or LPIPS class not found")
self.lpips_fn = lpips.LPIPS(net='alex', verbose=False).eval(); self.lpips_weight = lpips_weight; self.ssim_weight = ssim_weight
lap = torch.tensor([[0.,1.,0.],[1.,-4.,1.],[0.,1.,0.]]); self.register_buffer('lap_kernel', lap.unsqueeze(0).unsqueeze(0))
def gaussian_pyramid(self, img: torch.Tensor, max_levels: int) -> list:
k1d = torch.tensor([1.,4.,6.,4.,1.], device=img.device) / 16.; ker = (k1d[:,None] * k1d[None,:]).unsqueeze(0).unsqueeze(0).repeat(img.size(1),1,1,1)
pyr = [img]; cur = img
for _ in range(max_levels):
try: blurred = F.conv2d(F.pad(cur, (2,2,2,2), mode='reflect'), ker, groups=img.size(1)); cur = blurred[:, :, ::2, ::2]; pyr.append(cur)
except Exception as e: print(f"Error in gaussian_pyramid conv: {e}"); break
return pyr
def generate_focus_mask(self, img: torch.Tensor, threshold: float = 0.03) -> torch.Tensor:
if img.shape[1] != 3: print("Warning: Generating focus mask from non-3-channel image.")
try: img_float = img.float(); gray = img_float.mean(dim=1, keepdim=True); lap_kernel_device = self.lap_kernel.to(gray.device); lap = F.conv2d(F.pad(gray, (1,1,1,1), mode='reflect'), lap_kernel_device); return (lap.abs() > threshold).float().to(img.device)
except Exception as e: print(f"Error generating focus mask: {e}"); return torch.ones_like(img[:, 0:1, :, :])
def forward(self, pred: torch.Tensor, target: torch.Tensor, original: torch.Tensor) -> (torch.Tensor, dict):
pred_device = pred.device; target = target.to(pred_device); original = original.to(pred_device)
l1_loss, ssim_loss = 0.0, 0.0; loss_components = {}
try:
mask_full = self.generate_focus_mask(original); pred = torch.clamp(pred, 0.0, 1.0); target = torch.clamp(target, 0.0, 1.0)
pred_pyr = self.gaussian_pyramid(pred, self.max_levels); tgt_pyr = self.gaussian_pyramid(target, self.max_levels)
if len(pred_pyr) != len(tgt_pyr): raise ValueError(f"Pyramid lengths differ: Pred {len(pred_pyr)}, Target {len(tgt_pyr)}")
for lvl, (p, t) in enumerate(zip(pred_pyr, tgt_pyr)):
if p.shape != t.shape: continue
w_lvl = self.level_weights[lvl]; mask_lvl = F.interpolate(mask_full, size=p.shape[2:], mode='nearest'); pixel_w = mask_lvl * self.fg_w + (1 - mask_lvl) * self.bg_w
l1 = (p - t).abs(); weighted_l1_map = l1 * pixel_w; current_l1_loss = w_lvl * weighted_l1_map.mean(); l1_loss += current_l1_loss; loss_components[f'l1_lvl_{lvl}'] = current_l1_loss.item()
try: ssim_val = ssim_pytorch(p, t, data_range=1.0, size_average=True); avg_w = pixel_w.mean(); current_ssim_loss = w_lvl * (1 - ssim_val) * avg_w; ssim_loss += current_ssim_loss; loss_components[f'ssim_lvl_{lvl}'] = current_ssim_loss.item()
except Exception: loss_components[f'ssim_lvl_{lvl}'] = 0.0
try: lpips_loss_val = self.lpips_fn(pred * 2.0 - 1.0, target * 2.0 - 1.0).mean(); loss_components['lpips'] = lpips_loss_val.item() * self.lpips_weight
except Exception: lpips_loss_val = torch.tensor(0.0).to(pred_device); loss_components['lpips'] = 0.0
total = l1_loss + self.ssim_weight * ssim_loss + self.lpips_weight * lpips_loss_val
loss_components['l1_multi_weighted'] = l1_loss.item(); loss_components['ssim_multi_weighted'] = ssim_loss.item() * self.ssim_weight
return total, loss_components
except Exception as e: print(f"ERROR in WeightedPyNetPerceptualLoss forward pass: {e}"); traceback.print_exc(); return torch.tensor(0.0, device=pred_device, requires_grad=True), {}
class bokeh_CNN_Deeper(nn.Module): # deeper lvel 4 CNN used for the final loop
def __init__(self, in_channels=3, out_channels=3, init_features=32):
super(bokeh_CNN_Deeper, self).__init__(); features = init_features
self.enc1 = nn.Sequential(nn.Conv2d(in_channels, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True), nn.Conv2d(features, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True)); self.pool1 = nn.MaxPool2d(2, 2)
self.enc2 = nn.Sequential(nn.Conv2d(features, features*2, 3, padding=1, bias=False), nn.BatchNorm2d(features*2), nn.ReLU(True), nn.Conv2d(features*2, features*2, 3, padding=1, bias=False), nn.BatchNorm2d(features*2), nn.ReLU(True)); self.pool2 = nn.MaxPool2d(2, 2)
self.enc3 = nn.Sequential(nn.Conv2d(features*2, features*4, 3, padding=1, bias=False), nn.BatchNorm2d(features*4), nn.ReLU(True), nn.Conv2d(features*4, features*4, 3, padding=1, bias=False), nn.BatchNorm2d(features*4), nn.ReLU(True)); self.pool3 = nn.MaxPool2d(2, 2)
self.enc4 = nn.Sequential(nn.Conv2d(features*4, features*8, 3, padding=1, bias=False), nn.BatchNorm2d(features*8), nn.ReLU(True), nn.Conv2d(features*8, features*8, 3, padding=1, bias=False), nn.BatchNorm2d(features*8), nn.ReLU(True))
self.up3 = nn.ConvTranspose2d(features*8, features*4, 2, stride=2); self.dec3 = nn.Sequential(nn.Conv2d(features*8, features*4, 3, padding=1, bias=False), nn.BatchNorm2d(features*4), nn.ReLU(True), nn.Conv2d(features*4, features*4, 3, padding=1, bias=False), nn.BatchNorm2d(features*4), nn.ReLU(True))
self.up2 = nn.ConvTranspose2d(features*4, features*2, 2, stride=2); self.dec2 = nn.Sequential(nn.Conv2d(features*4, features*2, 3, padding=1, bias=False), nn.BatchNorm2d(features*2), nn.ReLU(True), nn.Conv2d(features*2, features*2, 3, padding=1, bias=False), nn.BatchNorm2d(features*2), nn.ReLU(True))
self.up1 = nn.ConvTranspose2d(features*2, features, 2, stride=2); self.dec1 = nn.Sequential(nn.Conv2d(features*2, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True), nn.Conv2d(features, features, 3, padding=1, bias=False), nn.BatchNorm2d(features), nn.ReLU(True))
self.output_conv = nn.Conv2d(features, out_channels, 1, bias=True)
def forward(self, x):
try:
x1=self.enc1(x); p1=self.pool1(x1); x2=self.enc2(p1); p2=self.pool2(x2); x3=self.enc3(p2); p3=self.pool3(x3); x4=self.enc4(p3)
u3=self.up3(x4); dY=x3.size(2)-u3.size(2); dX=x3.size(3)-u3.size(3); u3=F.pad(u3,[dX//2,dX-dX//2,dY//2,dY-dY//2]); cat3=torch.cat([u3,x3],dim=1); d3=self.dec3(cat3)
u2=self.up2(d3); dY=x2.size(2)-u2.size(2); dX=x2.size(3)-u2.size(3); u2=F.pad(u2,[dX//2,dX-dX//2,dY//2,dY-dY//2]); cat2=torch.cat([u2,x2],dim=1); d2=self.dec2(cat2)
u1=self.up1(d2); dY=x1.size(2)-u1.size(2); dX=x1.size(3)-u1.size(3); u1=F.pad(u1,[dX//2,dX-dX//2,dY//2,dY-dY//2]); cat1=torch.cat([u1,x1],dim=1); d1=self.dec1(cat1)
out=self.output_conv(d1); return torch.sigmoid(out)
except Exception as e: print(f"Error in bokeh_CNN_Deeper forward pass: {e}"); traceback.print_exc(); raise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Running on device: {device} ---")
# single transform, no augm.
transform_no_aug = Compose([Resize((IMG_SIZE, IMG_SIZE)), ToTensor()])
print("Defined unified transform (No Augmentation).")
# datasets & loaders
train_loader = None; val_loader = None; test_loader = None
try:
if 'BASE_PATH' not in globals() or not os.path.isdir(BASE_PATH): raise ValueError("'BASE_PATH' not defined")
original_train_dir=os.path.join(BASE_PATH, 'train/original'); bokeh_train_dir=os.path.join(BASE_PATH, 'train/bokeh')
new_val_original_dir=os.path.join(BASE_PATH, 'validation/original'); new_val_bokeh_dir=os.path.join(BASE_PATH, 'validation/bokeh')
test_original_dir = os.path.join(BASE_PATH, 'test/original'); test_bokeh_dir = os.path.join(BASE_PATH, 'test/bokeh')
print("Loading full original training dataset index for subsetting...")
full_new_train_dataset = BokehDataset(original_dir=original_train_dir, bokeh_dir=bokeh_train_dir, transform=transform_no_aug)
num_actual_train_samples = len(full_new_train_dataset); assert num_actual_train_samples > 0
# training subset creation
train_subset_size = min(TRAIN_SUBSET_SIZE, num_actual_train_samples)
train_dataset_subset, _ = random_split(full_new_train_dataset, [train_subset_size, num_actual_train_samples - train_subset_size], generator=torch.Generator().manual_seed(42))
print(f"Created Training Subset of size: {len(train_dataset_subset)}")
# load FULL val. dataset
print("Loading full new 'validation' dataset...")
val_dataset = BokehDataset(original_dir=new_val_original_dir, bokeh_dir=new_val_bokeh_dir, transform=transform_no_aug)
num_actual_val_samples = len(val_dataset); assert num_actual_val_samples > 0
print(f"Found {num_actual_val_samples} pairs in 'validation'.")
# Lload FULL test dataset
if os.path.isdir(test_original_dir) and os.path.isdir(test_bokeh_dir):
print("Loading full 'test' dataset...")
test_dataset = BokehDataset(original_dir=test_original_dir, bokeh_dir=test_bokeh_dir, transform=transform_no_aug)
print(f"Found {len(test_dataset)} pairs in 'test'.")
if len(test_dataset) == 0: test_dataset = None # Handle empty test set
else: print("Test dataset directory not found."); test_dataset = None
# create dataloaders
def collate_fn_skip_none(batch): batch=list(filter(lambda x: x is not None and x[0] is not None and x[1] is not None, batch)); return torch.utils.data.dataloader.default_collate(batch) if batch else None
num_workers = 2; BATCH_SIZE = TRAINING_HP['batch_size']
train_loader = DataLoader(train_dataset_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn_skip_none, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn_skip_none)
test_loader = None
if test_dataset: test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn_skip_none)
print(f"\nDataLoaders created (Batch Size: {BATCH_SIZE}): Train ({len(train_loader)} batches), Validation ({len(val_loader)} batches){f', Test ({len(test_loader)} batches)' if test_loader else ', No Test Loader'}")
assert len(train_loader)>0 and len(val_loader)>0
except Exception as e: print(f"Error setting up Datasets/DataLoaders: {e}"); raise
# -- Model, Optimizer, Loss, Scheduler - -
# Instantiate the DEEPER model (Lvl4)
model = bokeh_CNN_Deeper(**MODEL_HP).to(device) # MODEL_HP uses init_features=64
optimizer = optim.Adam(model.parameters(), lr=TRAINING_HP['learning_rate'])
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=SCHEDULER_FACTOR, patience=SCHEDULER_PATIENCE, verbose=True)
try:
criterion = WeightedPyNetPerceptualLoss(lpips_weight=0.5, ssim_weight=1.0).to(device); lpips_eval_fn = lpips.LPIPS(net='alex', verbose=False).to(device)
print("Model, Optimizer, Loss, Scheduler instantiated.")
except Exception as e: print(f"Error instantiating components: {e}"); raise
# checkpoint loading -
start_epoch = 0; best_val_loss = float('inf'); history = {'train_loss': [], 'val_loss': [], 'val_l1_loss': []}
if os.path.exists(CHECKPOINT_PATH):
print(f"Attempting to load checkpoint: {CHECKPOINT_PATH}")
try:
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu'); model_state_dict = checkpoint['model_state_dict']
# loading state dict for the deeper model
model.load_state_dict(model_state_dict); model.to(device);
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor): state[k] = v.to(device)
start_epoch = checkpoint['epoch'] + 1; best_val_loss = checkpoint.get('best_val_loss', float('inf'))
if 'history' in checkpoint: history = checkpoint['history']; print(f"Loaded history with {len(history.get('train_loss',[]))} epochs.")
else: history = {'train_loss': [], 'val_loss': [], 'val_l1_loss': []} # init if missing
if 'scheduler_state_dict' in checkpoint: scheduler.load_state_dict(checkpoint['scheduler_state_dict']); print("Loaded LR scheduler state.")
print(f"Resuming training from Epoch {start_epoch}, Best Val Loss: {best_val_loss:.4f}")
except RuntimeError as e: print(f"Error loading state_dict (likely model mismatch): {e}. Starting fresh."); start_epoch = 0; best_val_loss = float('inf'); history = {'train_loss': [], 'val_loss': [], 'val_l1_loss': []}
except Exception as e: print(f"Error loading checkpoint: {e}. Starting fresh."); start_epoch = 0; best_val_loss = float('inf'); history = {'train_loss': [], 'val_loss': [], 'val_l1_loss': []}
else: print("No checkpoint found. Starting training from scratch.")
# definining eval. function
def evaluate_model(model, dataloader, lpips_eval_fn, crop_border=0, ssim_win_size=11):
model.eval(); eval_device = next(model.parameters()).device; total_psnr, total_ssim_sk, total_msssim, total_lpips = 0.0, 0.0, 0.0, 0.0; lpips_msssim_count, psnr_ssim_count = 0, 0
if dataloader is None: return 0,0,0,0
with torch.no_grad():
for batch in tqdm(dataloader, desc="Evaluating", leave=False, disable=len(dataloader)<5):
if batch is None: continue
try: inputs, targets = batch; assert isinstance(inputs, torch.Tensor) and isinstance(targets, torch.Tensor)
except (AssertionError, TypeError, ValueError): continue
if inputs.shape[1] != 3 or targets.shape[1] != 3: continue
inputs, targets = inputs.to(eval_device), targets.to(eval_device); outputs = model(inputs)
outputs_clipped = torch.clamp(outputs, 0.0, 1.0); current_batch_size = outputs_clipped.shape[0]; outputs_lpips_input = outputs_clipped * 2.0 - 1.0; targets_lpips_input = targets * 2.0 - 1.0
try: lpips_val = lpips_eval_fn(outputs_lpips_input, targets_lpips_input.detach()); total_lpips += lpips_val.sum().item(); lpips_msssim_count += current_batch_size
except Exception: pass
try: ms_ssim_val_batch = ms_ssim(outputs_clipped, targets.detach(), data_range=1.0, size_average=False); total_msssim += ms_ssim_val_batch.sum().item()
except Exception: pass
outputs_np_full = outputs_clipped.cpu().numpy(); targets_np_full = targets.cpu().numpy()
for i in range(current_batch_size):
out_img_full_np = np.transpose(outputs_np_full[i], (1, 2, 0)); tgt_img_full_np = np.transpose(targets_np_full[i], (1, 2, 0)); h, w = out_img_full_np.shape[:2]
if crop_border > 0 and h > 2 * crop_border and w > 2 * crop_border: out_img_eval_np=out_img_full_np[crop_border:-crop_border, crop_border:-crop_border, :]; tgt_img_eval_np=tgt_img_full_np[crop_border:-crop_border, crop_border:-crop_border, :]
else: out_img_eval_np=out_img_full_np; tgt_img_eval_np=tgt_img_full_np
try:
psnr=psnr_metric_skimage(tgt_img_eval_np, out_img_eval_np, data_range=1.0); ch_sk, cw_sk = tgt_img_eval_np.shape[:2]; current_win_size=min(ssim_win_size, ch_sk, cw_sk)
if current_win_size % 2 == 0: current_win_size -= 1; current_win_size=max(3, current_win_size)
if ch_sk >= current_win_size and cw_sk >= current_win_size : ssim_val_sk=ssim_metric_skimage(tgt_img_eval_np, out_img_eval_np, channel_axis=-1, data_range=1.0, win_size=current_win_size, gaussian_weights=True, multichannel=True)
else: ssim_val_sk = np.nan
if not np.isnan(psnr) and not np.isnan(ssim_val_sk): total_psnr+=psnr; total_ssim_sk+=ssim_val_sk; psnr_ssim_count+=1
except ValueError: pass
avg_psnr=total_psnr/psnr_ssim_count if psnr_ssim_count>0 else 0; avg_ssim_sk=total_ssim_sk/psnr_ssim_count if psnr_ssim_count > 0 else 0; avg_msssim=total_msssim/lpips_msssim_count if lpips_msssim_count > 0 else 0; avg_lpips=total_lpips/lpips_msssim_count if lpips_msssim_count > 0 else 0
return avg_psnr, avg_ssim_sk, avg_msssim, avg_lpips
# - train loop --
print(f"\n--- Starting Final Training Run: {RUN_ID} ---")
patience_counter = 0; training_start_time = time.time(); epochs_completed_this_run = 0
history.setdefault('train_loss', []); history.setdefault('val_loss', []); history.setdefault('val_l1_loss', [])
try:
for epoch in range(start_epoch, TRAINING_HP['target_epochs']):
epoch_start_time = time.time()
model.train(); running_train_loss = 0.0
pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{TRAINING_HP['target_epochs']} [Train]", leave=False)
for i, batch in pbar:
if batch is None: continue
try: inputs, targets = batch; assert isinstance(inputs, torch.Tensor) and isinstance(targets, torch.Tensor)
except (AssertionError, TypeError, ValueError): continue
if inputs.shape[1] != 3 or targets.shape[1] != 3: continue
try:
inputs, targets = inputs.to(device), targets.to(device); optimizer.zero_grad(); outputs = model(inputs); total_loss, _ = criterion(pred=outputs, target=targets, original=inputs)
if torch.isnan(total_loss) or torch.isinf(total_loss): continue
total_loss.backward();
if 'GRAD_CLIP_VALUE' in globals() and GRAD_CLIP_VALUE: torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
optimizer.step(); running_train_loss += total_loss.item(); pbar.set_postfix({'Loss': f'{total_loss.item():.4f}'})
except Exception as batch_err: print(f"Error processing train batch {i}: {batch_err}"); traceback.print_exc(); continue
avg_train_loss = running_train_loss / len(train_loader) if len(train_loader) > 0 else 0
if len(history['train_loss']) <= epoch: history['train_loss'].append(avg_train_loss)
else: history['train_loss'][epoch] = avg_train_loss # overwrite if resuming
# - validation phase
model.eval(); running_val_loss = 0.0; running_val_l1_loss = 0.0; avg_val_l1_loss = float('nan')
if len(val_loader) == 0: print("Warning: Val loader empty."); avg_val_loss = best_val_loss
else:
with torch.no_grad():
for batch_val in tqdm(val_loader, desc=f"Epoch {epoch+1} [Validation]", leave=False):
if batch_val is None: continue
try: inputs_val, targets_val = batch_val; assert isinstance(inputs_val, torch.Tensor) and isinstance(targets_val, torch.Tensor)
except (AssertionError, TypeError, ValueError): continue
if inputs_val.shape[1] != 3 or targets_val.shape[1] != 3: continue
try:
inputs_val, targets_val = inputs_val.to(device), targets_val.to(device); outputs_val = model(inputs_val)
val_total_loss, val_loss_components = criterion(pred=outputs_val, target=targets_val, original=inputs_val)
if torch.isnan(val_total_loss) or torch.isinf(val_total_loss): continue
running_val_loss += val_total_loss.item(); running_val_l1_loss += val_loss_components.get('l1_multi_weighted', 0.0)
except Exception as val_batch_err: print(f"Error processing validation batch: {val_batch_err}"); continue
avg_val_loss = running_val_loss / len(val_loader) if len(val_loader) > 0 else float('inf')
avg_val_l1_loss = running_val_l1_loss / len(val_loader) if len(val_loader) > 0 else float('nan')
# Store history
history.setdefault('val_loss', []).append(avg_val_loss)
history.setdefault('val_l1_loss', []).append(avg_val_l1_loss)
if len(history['val_loss']) > epoch + 1 : history['val_loss'][epoch] = avg_val_loss
if len(history['val_l1_loss']) > epoch + 1 : history['val_l1_loss'][epoch] = avg_val_l1_loss
epochs_completed_this_run += 1; epoch_duration = time.time() - epoch_start_time
print(f"Epoch {epoch+1}/{TRAINING_HP['target_epochs']} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f} - Val L1: {avg_val_l1_loss:.4f} - Time: {epoch_duration:.2f}s")
# - - LR scheduler step
scheduler.step(avg_val_loss)
# -- checpointing & early stopping
is_best = avg_val_loss < best_val_loss
if is_best:
best_val_loss = avg_val_loss; patience_counter = 0; print(f" New best validation loss: {best_val_loss:.4f}. Saving best model...")
try: torch.save(model.state_dict(), BEST_MODEL_PATH)
except Exception as e: print(f" ERROR saving best model: {e}")
else: patience_counter += 1; print(f" Validation loss did not improve. ES Patience: {patience_counter}/{EARLY_STOPPING_PATIENCE}")
# save checkpoint periodically OR if it's the best model OR if it's the last epoch - to my google drive
if (epoch + 1) % CHECKPOINT_FREQ == 0 or epoch == TRAINING_HP['target_epochs'] - 1 or is_best:
print(f" Saving checkpoint at epoch {epoch+1}...")
checkpoint_data = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_val_loss': best_val_loss, 'history': history}
try: torch.save(checkpoint_data, CHECKPOINT_PATH)
except Exception as e: print(f" ERROR saving checkpoint: {e}")
if patience_counter >= EARLY_STOPPING_PATIENCE: print(f"\nEarly stopping triggered after epoch {epoch+1}."); break
except KeyboardInterrupt: print("\nTraining stopped manually."); print("Saving final checkpoint..."); checkpoint_data = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_val_loss': best_val_loss, 'history': history}; torch.save(checkpoint_data, CHECKPOINT_PATH); print("Checkpoint saved.")
except Exception as train_err: print(f"\n!!! ERROR during training loop: {train_err} !!!"); traceback.print_exc()
# end of training loop
total_training_time_mins = (time.time() - training_start_time) / 60.0
print(f"\n--- Training Finished ({total_training_time_mins:.2f} mins over {epochs_completed_this_run} epochs this run) ---")
# Final eval on test set
print("\n--- Evaluating Best Model on DEDICATED Test Set ---")
final_test_loss, final_test_l1 = float('nan'), float('nan'); final_test_psnr, final_test_ssim, final_test_msssim, final_test_lpips = [float('nan')] * 4; avg_inference_time_ms = float('nan')
if test_loader is None: print("ERROR: Test loader not available. Skipping test set evaluation.")
elif not os.path.exists(BEST_MODEL_PATH): print("ERROR: Best model file not found. Cannot evaluate.")
else:
print(f"Loading best model from: {BEST_MODEL_PATH}")
try:
# instantiate the DEEPER model for evaluation
best_model = bokeh_CNN_Deeper(**MODEL_HP).to(device)
best_model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device)); best_model.eval()
# calc. standard metrics using evaluate_model on TEST data
final_test_psnr, final_test_ssim, final_test_msssim, final_test_lpips = evaluate_model(best_model, test_loader, lpips_eval_fn, EVAL_CROP_BORDER, EVAL_SSIM_WIN_SIZE)
# calculate Test Loss (Total and L1 component) & Inference Time on TEST data
test_loss_total = 0.0; test_loss_l1 = 0.0; inference_times = []
with torch.no_grad():
for batch in tqdm(test_loader, desc="Final Test Eval & Inference Timing"):
if batch is None: continue
try: inputs, targets = batch; assert isinstance(inputs, torch.Tensor) and isinstance(targets, torch.Tensor)
except (AssertionError, TypeError, ValueError): continue
if inputs.shape[1] != 3 or targets.shape[1] != 3: continue
try:
inputs, targets = inputs.to(device), targets.to(device)
torch.cuda.synchronize(); inf_start = time.time(); outputs = best_model(inputs); torch.cuda.synchronize(); inf_end = time.time()
if inputs.size(0) > 0: inference_times.append((inf_end - inf_start) * 1000 / inputs.size(0))
batch_loss, batch_loss_components = criterion(pred=outputs, target=targets, original=inputs)
if torch.isnan(batch_loss): print("Warning: NaN test loss detected."); continue
test_loss_total += batch_loss.item(); test_loss_l1 += batch_loss_components.get('l1_multi_weighted', 0.0)
except Exception as test_batch_err: print(f"Error processing test batch: {test_batch_err}"); continue
final_test_loss = test_loss_total / len(test_loader) if len(test_loader) > 0 else float('nan')
final_test_l1 = test_loss_l1 / len(test_loader) if len(test_loader) > 0 else float('nan')
avg_inference_time_ms = sum(inference_times) / len(inference_times) if inference_times else float('nan')
except Exception as e: print(f"Error during final test evaluation: {e}"); traceback.print_exc()
# plotting train & val loss
print("\n--- Plotting Losses ---")
# use full history potentially loaded from checkpoint
epochs_in_history = len(history.get('train_loss', []))
if epochs_in_history > 0:
# Adjust range calculation for potential resumes
actual_start_epoch = (start_epoch - epochs_in_history + epochs_completed_this_run) if start_epoch > 0 else 1
epochs_range = range(actual_start_epoch, actual_start_epoch + epochs_in_history)
plot_train_loss = history['train_loss']
plot_val_loss = history['val_loss']
if len(plot_train_loss) != len(plot_val_loss): # handle potential mismatch if stopped mid-epoch duringa load
print("Warning: History length mismatch for plotting losses.")
min_len = min(len(plot_train_loss), len(plot_val_loss))
plot_train_loss = plot_train_loss[:min_len]
plot_val_loss = plot_val_loss[:min_len]
epochs_range = range(actual_start_epoch, actual_start_epoch + min_len)
plt.figure(figsize=(10, 5));
if plot_train_loss: plt.plot(epochs_range, plot_train_loss, label='Training Loss', marker='.')
if plot_val_loss: plt.plot(epochs_range, plot_val_loss, label='Validation Loss', marker='.')
plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.title(f'Training and Validation Loss ({RUN_ID})');
if plot_train_loss or plot_val_loss: plt.legend();
plt.grid(True);
if len(epochs_range) > 0: plt.xticks(list(epochs_range)[::max(1, len(epochs_range)//10)])
plot_save_path = os.path.join(DRIVE_SAVE_DIR, f"{RUN_ID}_loss_plot.png")
try: plt.savefig(plot_save_path); print(f"Loss plot saved to: {plot_save_path}")
except Exception as e: print(f"Error saving plot: {e}")
plt.show()
else: print("No training history found to plot.")
# - final metrics
print("\n--- Final Performance Report ---")
print(f"Run ID: {RUN_ID}"); print(f"Dataset: Train Subset ({TRAIN_SUBSET_SIZE}), Full Validation, Full Test (No Augmentation)")
total_epochs_trained = start_epoch + epochs_completed_this_run; print(f"Total Training Time: {total_training_time_mins:.2f} minutes ({epochs_completed_this_run} epochs this run / {total_epochs_trained} total epochs)"); print(f"Best Val Loss Achieved: {best_val_loss:.4f}")
final_train_loss = history.get('train_loss', [float('nan')])[-1]; print(f"Final Avg Training Loss (at end of run): {final_train_loss:.4f}")
print("-" * 40); print("FINAL Test Set Performance (Best Model):")
print(f" Avg Test Loss (Total): {final_test_loss:.4f}"); print(f" Avg Test L1 Loss Comp: {final_test_l1:.4f}")
print(f" PSNR: {final_test_psnr:.2f} dB"); print(f" SSIM (skimage): {final_test_ssim:.4f}")
print(f" MS-SSIM: {final_test_msssim:.4f}"); print(f" LPIPS (AlexNet): {final_test_lpips:.4f}")
print(f" Avg Inference Time: {avg_inference_time_ms:.2f} ms/image (T4 GPU)"); print("-" * 40)
print(f"\nBest model (based on validation loss) saved at: {BEST_MODEL_PATH}"); print(f"Final periodic checkpoint saved at: {CHECKPOINT_PATH}")
Models/Checkpoints will be saved in: /content/drive/MyDrive/Bokeh_Training_Runs --- Running on device: cuda --- Defined unified transform (No Augmentation). Loading full original training dataset index for subsetting... Created Training Subset of size: 2048 Loading full new 'validation' dataset... Found 938 pairs in 'validation'. Loading full 'test' dataset... Found 200 pairs in 'test'. DataLoaders created (Batch Size: 4): Train (512 batches), Validation (235 batches), Test (50 batches)
/usr/local/lib/python3.11/dist-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate. warnings.warn( /usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead. warnings.warn( /usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights. warnings.warn(msg) Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth 100%|██████████| 233M/233M [00:00<00:00, 247MB/s]
Model, Optimizer, Loss, Scheduler instantiated. No checkpoint found. Starting training from scratch. --- Starting Final Training Run: VUNetDeep_IF64_B4_E50_ES6_LRsch3_Sub2048_Final ---
Epoch 1/30 - Train Loss: 0.9335 - Val Loss: 0.7696 - Val L1: 0.1488 - Time: 495.52s New best validation loss: 0.7696. Saving best model... Saving checkpoint at epoch 1...
Epoch 2/30 - Train Loss: 0.8598 - Val Loss: 0.7759 - Val L1: 0.1581 - Time: 497.61s Validation loss did not improve. ES Patience: 1/6
Epoch 3/30 - Train Loss: 0.8276 - Val Loss: 0.7151 - Val L1: 0.1444 - Time: 498.66s New best validation loss: 0.7151. Saving best model... Saving checkpoint at epoch 3...
Epoch 4/30 - Train Loss: 0.7969 - Val Loss: 0.7084 - Val L1: 0.1441 - Time: 500.55s New best validation loss: 0.7084. Saving best model... Saving checkpoint at epoch 4...
Epoch 5/30 - Train Loss: 0.7638 - Val Loss: 0.6853 - Val L1: 0.1427 - Time: 501.07s New best validation loss: 0.6853. Saving best model... Saving checkpoint at epoch 5...
Epoch 6/30 - Train Loss: 0.7442 - Val Loss: 0.6700 - Val L1: 0.1426 - Time: 500.86s New best validation loss: 0.6700. Saving best model... Saving checkpoint at epoch 6...
Epoch 7/30 - Train Loss: 0.7286 - Val Loss: 0.6490 - Val L1: 0.1397 - Time: 501.65s New best validation loss: 0.6490. Saving best model... Saving checkpoint at epoch 7...
Epoch 8/30 - Train Loss: 0.7130 - Val Loss: 0.6462 - Val L1: 0.1409 - Time: 501.02s New best validation loss: 0.6462. Saving best model... Saving checkpoint at epoch 8...
Epoch 9/30 - Train Loss: 0.7073 - Val Loss: 0.6640 - Val L1: 0.1424 - Time: 500.57s Validation loss did not improve. ES Patience: 1/6
Epoch 10/30 - Train Loss: 0.6957 - Val Loss: 0.6292 - Val L1: 0.1363 - Time: 501.57s New best validation loss: 0.6292. Saving best model... Saving checkpoint at epoch 10...
Epoch 11/30 - Train Loss: 0.6935 - Val Loss: 0.6303 - Val L1: 0.1375 - Time: 500.54s Validation loss did not improve. ES Patience: 1/6
Epoch 12/30 - Train Loss: 0.6828 - Val Loss: 0.6256 - Val L1: 0.1385 - Time: 500.77s New best validation loss: 0.6256. Saving best model... Saving checkpoint at epoch 12...
Epoch 13/30 - Train Loss: 0.6758 - Val Loss: 0.6289 - Val L1: 0.1374 - Time: 501.46s Validation loss did not improve. ES Patience: 1/6
Epoch 14/30 - Train Loss: 0.6742 - Val Loss: 0.6671 - Val L1: 0.1514 - Time: 501.68s Validation loss did not improve. ES Patience: 2/6
Epoch 15/30 - Train Loss: 0.6696 - Val Loss: 0.6418 - Val L1: 0.1412 - Time: 501.48s Validation loss did not improve. ES Patience: 3/6 Saving checkpoint at epoch 15...
Epoch 16/30 - Train Loss: 0.6669 - Val Loss: 0.6392 - Val L1: 0.1415 - Time: 501.70s Validation loss did not improve. ES Patience: 4/6
Epoch 17/30 - Train Loss: 0.6383 - Val Loss: 0.6056 - Val L1: 0.1342 - Time: 502.07s New best validation loss: 0.6056. Saving best model... Saving checkpoint at epoch 17...
Epoch 18/30 - Train Loss: 0.6299 - Val Loss: 0.6163 - Val L1: 0.1419 - Time: 502.54s Validation loss did not improve. ES Patience: 1/6
Epoch 19/30 - Train Loss: 0.6297 - Val Loss: 0.6060 - Val L1: 0.1355 - Time: 501.84s Validation loss did not improve. ES Patience: 2/6
Epoch 20/30 - Train Loss: 0.6255 - Val Loss: 0.6040 - Val L1: 0.1349 - Time: 501.88s New best validation loss: 0.6040. Saving best model... Saving checkpoint at epoch 20...
Epoch 21/30 - Train Loss: 0.6238 - Val Loss: 0.5965 - Val L1: 0.1324 - Time: 502.66s New best validation loss: 0.5965. Saving best model... Saving checkpoint at epoch 21...
Epoch 22/30 - Train Loss: 0.6200 - Val Loss: 0.6025 - Val L1: 0.1352 - Time: 501.96s Validation loss did not improve. ES Patience: 1/6
Epoch 23/30 - Train Loss: 0.6208 - Val Loss: 0.5988 - Val L1: 0.1347 - Time: 502.83s Validation loss did not improve. ES Patience: 2/6
Epoch 24/30 - Train Loss: 0.6194 - Val Loss: 0.5932 - Val L1: 0.1326 - Time: 502.02s New best validation loss: 0.5932. Saving best model... Saving checkpoint at epoch 24...
Epoch 25/30 - Train Loss: 0.6170 - Val Loss: 0.5988 - Val L1: 0.1327 - Time: 501.62s Validation loss did not improve. ES Patience: 1/6 Saving checkpoint at epoch 25...
Epoch 26/30 - Train Loss: 0.6163 - Val Loss: 0.6025 - Val L1: 0.1358 - Time: 503.12s Validation loss did not improve. ES Patience: 2/6
Epoch 27/30 - Train Loss: 0.6149 - Val Loss: 0.5970 - Val L1: 0.1338 - Time: 501.44s Validation loss did not improve. ES Patience: 3/6
Epoch 28/30 - Train Loss: 0.6144 - Val Loss: 0.5949 - Val L1: 0.1323 - Time: 502.40s Validation loss did not improve. ES Patience: 4/6
Epoch 29/30 - Train Loss: 0.6077 - Val Loss: 0.5883 - Val L1: 0.1330 - Time: 502.91s New best validation loss: 0.5883. Saving best model... Saving checkpoint at epoch 29...
Epoch 30/30 - Train Loss: 0.6064 - Val Loss: 0.5923 - Val L1: 0.1324 - Time: 503.25s Validation loss did not improve. ES Patience: 1/6 Saving checkpoint at epoch 30... --- Training Finished (250.75 mins over 30 epochs this run) --- --- Evaluating Best Model on DEDICATED Test Set --- Loading best model from: /content/drive/MyDrive/Bokeh_Training_Runs/VUNetDeep_IF64_B4_E50_ES6_LRsch3_Sub2048_Final_best_model.pth
Final Test Eval & Inference Timing: 100%|██████████| 50/50 [00:15<00:00, 3.21it/s]
--- Plotting Losses --- Loss plot saved to: /content/drive/MyDrive/Bokeh_Training_Runs/VUNetDeep_IF64_B4_E50_ES6_LRsch3_Sub2048_Final_loss_plot.png
--- Final Performance Report --- Run ID: VUNetDeep_IF64_B4_E50_ES6_LRsch3_Sub2048_Final Dataset: Train Subset (2048), Full Validation, Full Test (No Augmentation) Total Training Time: 250.75 minutes (30 epochs this run / 30 total epochs) Best Val Loss Achieved: 0.5883 Final Avg Training Loss (at end of run): 0.6064 ---------------------------------------- FINAL Test Set Performance (Best Model): Avg Test Loss (Total): 0.5512 Avg Test L1 Loss Comp: 0.1321 PSNR: 24.12 dB SSIM (skimage): 0.8568 MS-SSIM: 0.8969 LPIPS (AlexNet): 0.1532 Avg Inference Time: 65.22 ms/image (T4 GPU) ---------------------------------------- Best model (based on validation loss) saved at: /content/drive/MyDrive/Bokeh_Training_Runs/VUNetDeep_IF64_B4_E50_ES6_LRsch3_Sub2048_Final_best_model.pth Final periodic checkpoint saved at: /content/drive/MyDrive/Bokeh_Training_Runs/VUNetDeep_IF64_B4_E50_ES6_LRsch3_Sub2048_Final_checkpoint.pth
7. Miscellaneous¶
In [ ]:
# @title BROWSE VALIDATION SET IMAGES for selecting in qualitative plots to match up with results PyNet generated (since their GH code won't run)
import random
BASE_PATH = '/content/ebb_dataset/'
VALIDATION_ORIGINAL_DIR = os.path.join(BASE_PATH, 'validation/original')
VALIDATION_BOKEH_DIR = os.path.join(BASE_PATH, 'validation/bokeh')
NUM_TO_SHOW = 10
RANDOM_SAMPLE = True
SAMPLE_SEED = 87
def get_image_pairs(orig_dir, bokeh_dir):
if not os.path.isdir(orig_dir) or not os.path.isdir(bokeh_dir):
print(f"ERROR: Directory not found: {orig_dir} or {bokeh_dir}")
return []
try:
original_files = set(f for f in os.listdir(orig_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG')))
bokeh_files = set(f for f in os.listdir(bokeh_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG')))
image_names = sorted(list(original_files.intersection(bokeh_files)))
return image_names
except Exception as e:
print(f"Error listing files: {e}")
return []
print("Loading validation image list...")
validation_image_names = get_image_pairs(VALIDATION_ORIGINAL_DIR, VALIDATION_BOKEH_DIR)
num_available = len(validation_image_names)
if num_available == 0:
print("ERROR: No image pairs found in the new validation directory.")
else:
print(f"Found {num_available} total pairs in the validation set.")
indices_to_show = []
if num_available > 0:
if RANDOM_SAMPLE:
random.seed(SAMPLE_SEED)
indices_all = list(range(num_available))
random.shuffle(indices_all)
indices_to_show = indices_all[:min(NUM_TO_SHOW, num_available)]
print(f"Showing {len(indices_to_show)} randomly selected images (Seed: {SAMPLE_SEED}).")
else:
indices_to_show = list(range(min(NUM_TO_SHOW, num_available)))
print(f"Showing first {len(indices_to_show)} images.")
print("-" * 30)
for i in indices_to_show:
img_filename = validation_image_names[i]
original_path = os.path.join(VALIDATION_ORIGINAL_DIR, img_filename)
bokeh_path = os.path.join(VALIDATION_BOKEH_DIR, img_filename)
try:
if not os.path.exists(original_path) or not os.path.exists(bokeh_path):
print(f"Warning: File missing for index {i}, skipping {img_filename}")
continue
with Image.open(original_path) as img_o, Image.open(bokeh_path) as img_b:
original_pil = img_o.convert('RGB')
bokeh_pil = img_b.convert('RGB')
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
fig.suptitle(f"Validation Set - Index: {i} | Filename: {img_filename}", fontsize=12)
axs[0].imshow(original_pil); axs[0].set_title("Original"); axs[0].axis('off')
axs[1].imshow(bokeh_pil); axs[1].set_title("Target Bokeh"); axs[1].axis('off')
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
print("-" * 30)
except Exception as e:
print(f"Error displaying image at index {i} (Filename: {img_filename}): {e}")
print("-" * 30)
In [ ]:
# @title CHECK EBB! DATASET STRUCTURE AND COUNTS
BASE_PATH = '/content/ebb_dataset/'
EXPECTED_SPLITS = ['train', 'validation', 'test']
EXPECTED_SUBDIRS = ['original', 'bokeh']
IMG_EXTENSIONS = ('*.png', '*.jpg', '*.jpeg', '*.JPG', '*.JPEG')
print(f"--- Checking Dataset Structure in: {BASE_PATH} ---")
if not os.path.isdir(BASE_PATH):
print(f"ERROR: Base path '{BASE_PATH}' does not exist or is not a directory.")
else:
found_splits_info = {}
# loop through the expected splits
for split in EXPECTED_SPLITS:
split_path = os.path.join(BASE_PATH, split)
print(f"\nChecking Split: '{split}' (Path: {split_path})")
if not os.path.isdir(split_path):
print(f" -> Status: DIRECTORY NOT FOUND.")
found_splits_info[split] = {'exists': False}
continue
found_splits_info[split] = {'exists': True, 'subdirs': {}}
original_files = set()
bokeh_files = set()
# loop through expected subdirs.
for subdir in EXPECTED_SUBDIRS:
subdir_path = os.path.join(split_path, subdir)
print(f" Checking Subdir: '{subdir}' (Path: {subdir_path})")
if not os.path.isdir(subdir_path):
print(f" -> Status: SUBDIRECTORY NOT FOUND.")
found_splits_info[split]['subdirs'][subdir] = {'exists': False, 'count': 0, 'files': set()}
continue
# count image files n.b. FOUND PACKAGE GLOB: using it for multiple extensions
image_paths = []
for ext in IMG_EXTENSIONS:
image_paths.extend(glob.glob(os.path.join(subdir_path, ext)))
file_count = len(image_paths)
filenames = set(os.path.basename(p) for p in image_paths)
print(f" -> Status: Found. Image Files Counted: {file_count}")
found_splits_info[split]['subdirs'][subdir] = {'exists': True, 'count': file_count, 'files': filenames}
if subdir == 'original':
original_files = filenames
elif subdir == 'bokeh':
bokeh_files = filenames
# compare original vs bokeh within the split
if found_splits_info[split]['subdirs'].get('original', {}).get('exists') and \
found_splits_info[split]['subdirs'].get('bokeh', {}).get('exists'):
matching_pairs = original_files.intersection(bokeh_files)
num_matching_pairs = len(matching_pairs)
print(f" Comparison: Found {num_matching_pairs} matching filenames between 'original' and 'bokeh'.")
original_only = original_files - bokeh_files
if original_only:
print(f" Warning: {len(original_only)} files in 'original' but not 'bokeh'. Example: {list(original_only)[0] if original_only else 'N/A'}")
bokeh_only = bokeh_files - original_files
if bokeh_only:
print(f" Warning: {len(bokeh_only)} files in 'bokeh' but not 'original'. Example: {list(bokeh_only)[0] if bokeh_only else 'N/A'}")
found_splits_info[split]['paired_count'] = num_matching_pairs
else:
found_splits_info[split]['paired_count'] = 0
print(f" Comparison: Cannot compare pairs as one or both subdirectories are missing.")
# summary
print("\n--- Dataset Structure Summary ---")
total_train_val_pairs = 0
test_pairs = 0
for split, info in found_splits_info.items():
if info['exists']:
count = info.get('paired_count', 0)
print(f"- Split '{split}': FOUND ({count} paired images)")
if split == 'train' or split == 'validation':
total_train_val_pairs += count
elif split == 'test':
test_pairs = count
else:
print(f"- Split '{split}': NOT FOUND")
print("-" * 30)
print(f"Total combined Train+Validation pairs found: {total_train_val_pairs} (Expected ~4800)")
print(f"Total Test pairs found: {test_pairs} (Expected 200)")
print("-" * 30)
--- Checking Dataset Structure in: /content/ebb_dataset/ ---
Checking Split: 'train' (Path: /content/ebb_dataset/train)
Checking Subdir: 'original' (Path: /content/ebb_dataset/train/original)
-> Status: Found. Image Files Counted: 3756
Checking Subdir: 'bokeh' (Path: /content/ebb_dataset/train/bokeh)
-> Status: Found. Image Files Counted: 3756
Comparison: Found 3756 matching filenames between 'original' and 'bokeh'.
Checking Split: 'validation' (Path: /content/ebb_dataset/validation)
Checking Subdir: 'original' (Path: /content/ebb_dataset/validation/original)
-> Status: Found. Image Files Counted: 938
Checking Subdir: 'bokeh' (Path: /content/ebb_dataset/validation/bokeh)
-> Status: Found. Image Files Counted: 938
Comparison: Found 938 matching filenames between 'original' and 'bokeh'.
Checking Split: 'test' (Path: /content/ebb_dataset/test)
Checking Subdir: 'original' (Path: /content/ebb_dataset/test/original)
-> Status: Found. Image Files Counted: 200
Checking Subdir: 'bokeh' (Path: /content/ebb_dataset/test/bokeh)
-> Status: Found. Image Files Counted: 200
Comparison: Found 200 matching filenames between 'original' and 'bokeh'.
--- Dataset Structure Summary ---
- Split 'train': FOUND (3756 paired images)
- Split 'validation': FOUND (938 paired images)
- Split 'test': FOUND (200 paired images)
------------------------------
Total combined Train+Validation pairs found: 4694 (Expected ~4800)
Total Test pairs found: 200 (Expected 200)
------------------------------
Explored More Complex Attention Architectures (not in final report)¶
In [ ]:
# @title U-Net with Transformer Encoder Block to address potential bottleneck
# helper Module: positional encoding
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
"""
Args:
d_model: the dimension of the embeddings.
dropout: the dropout rate.
max_len: the maximum length of the input sequences.
"""
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.permute(1, 0, 2))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Tensor, shape [batch_size, seq_len, embedding_dim]
"""
# x.size(1) is the sequence length
if x.size(1) > self.pe.size(1):
raise ValueError(f"Input sequence length ({x.size(1)}) exceeds PositionalEncoding max_len ({self.pe.size(1)})")
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
# MAIN Module: U-Net with Transformer Bottleneck
class UNetWithTransformerBottleneck(nn.Module):
def __init__(self, in_channels=3, out_channels=3, init_features=32,
d_model=256, nhead=8, num_transformer_layers=1, dim_feedforward=1024,
transformer_dropout=0.1, pos_encoding_dropout=0.1, pos_encoding_max_len=5000):
"""
Args:
in_channels: Number of input channels (usually 3 for RGB).
out_channels: Number of output channels (usually 3 for RGB bokeh).
init_features: Number of features in the first convolutional layer.
d_model: Embedding dimension for the Transformer bottleneck.
nhead: Number of attention heads in the Transformer. Must divide d_model.
num_transformer_layers: Number of Transformer encoder layers in the bottleneck.
dim_feedforward: Dimension of the feedforward network in the Transformer.
transformer_dropout: Dropout rate within the Transformer layer(s).
pos_encoding_dropout: Dropout rate for positional encoding.
pos_encoding_max_len: Maximum sequence length for positional encoding buffer.
"""
super().__init__()
features = init_features
if d_model % nhead != 0:
raise ValueError(f"d_model ({d_model}) must be divisible by nhead ({nhead})")
# encoder
self.enc1 = self.unet_block(in_channels, features, name="enc1")
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.enc2 = self.unet_block(features, features * 2, name="enc2")
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.enc3 = self.unet_block(features * 2, features * 4, name="enc3")
# N.B: We stop the CNN encoder here
# bottleneck interface (CNN to Transformer)
self.input_proj = nn.Conv2d(features * 4, d_model, kernel_size=1) # Project to d_model
self.positional_encoding = PositionalEncoding(d_model, pos_encoding_dropout, pos_encoding_max_len)
# transformer bottleneck
try:
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=transformer_dropout,
activation=F.relu,
batch_first=True, # Input shape: (Batch, SeqLen, Features)
norm_first=True # applyin norm before attention/FFN as often more stable
)
except TypeError: # handle older pytorch versions
print("Warning: norm_first=True not supported in this PyTorch version. Using default (norm_first=False).")
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=transformer_dropout,
activation=F.relu,
batch_first=True
)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layers)
# bottleneck interface (Transformer back to CNN)
# NO explicit layer needed, just reshaping in forward pass.
# output channels from bottleneck will be d_model.
# DECODER
# Input channels for dec2 = d_model (from bottleneck) + features*2 (from enc2 skip)
self.up2 = nn.ConvTranspose2d(d_model, features * 2, kernel_size=2, stride=2)
self.dec2 = self.unet_block((features * 2) * 2, features * 2, name="dec2") # features*2 (up) + features*2 (skip)
# Input channels for dec1 = features*2 (from dec2) + features*1 (from enc1 skip)
self.up1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
self.dec1 = self.unet_block(features * 2, features, name="dec1") # features (up) + features (skip)
self.output_conv = nn.Conv2d(features, out_channels, kernel_size=1)
def forward(self, x):
# Encoder
enc1_out = self.enc1(x)
pool1_out = self.pool1(enc1_out)
enc2_out = self.enc2(pool1_out)
pool2_out = self.pool2(enc2_out)
enc3_out = self.enc3(pool2_out)
# Bottleneck
proj_features = self.input_proj(enc3_out) # (B, d_model, H_b, W_b)
B, C_dmodel, H_b, W_b = proj_features.shape
# flatten spatial dimensions & permute for Transformer
# (B, d_model, H_b * W_b) -> (B, H_b * W_b, d_model)
flattened_features = proj_features.flatten(2).permute(0, 2, 1)
seq_len = H_b * W_b
pos_encoded_features = self.positional_encoding(flattened_features)
transformer_output = self.transformer_encoder(pos_encoded_features) # (B, SeqLen, d_model)
# reshape back to spatial format
# (B, SeqLen, d_model) -> (B, d_model, SeqLen) -> (B, d_model, H_b, W_b)
reshaped_output = transformer_output.permute(0, 2, 1).reshape(B, C_dmodel, H_b, W_b)
# DECODER
# Decoder lvl2
up2_out = self.up2(reshaped_output)
# handles potential size mismatch with skip connection (enc2_out)
diffY = enc2_out.size()[2] - up2_out.size()[2]
diffX = enc2_out.size()[3] - up2_out.size()[3]
padded_up2_out = F.pad(up2_out, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
cat2 = torch.cat([padded_up2_out, enc2_out], dim=1)
dec2_out = self.dec2(cat2)
# Decoder lvl1
up1_out = self.up1(dec2_out)
# handle potential size mismatch with skip connection (enc1_out)
diffY = enc1_out.size()[2] - up1_out.size()[2]
diffX = enc1_out.size()[3] - up1_out.size()[3]
padded_up1_out = F.pad(up1_out, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
cat1 = torch.cat([padded_up1_out, enc1_out], dim=1)
dec1_out = self.dec1(cat1)
# final output
output = self.output_conv(dec1_out)
return torch.sigmoid(output)
@staticmethod
def unet_block(in_channels, features, name):
# standard U-Net block: Conv -> BN -> ReLU -> Conv -> BN -> ReLU
return nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(num_features=features),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(num_features=features),
nn.ReLU(inplace=True),
)
# example Instantiation & Test
print("\n--- Running Basic Model Instantiation Test ---")
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device for test: {device}")
img_size_test = 128 # Example input size for test
batch_size_test = 2
d_model_test = 128
nhead_test = 4
init_features_test = 32
bottleneck_side_test = img_size_test // (2**2)
max_len_estimate_test = bottleneck_side_test * bottleneck_side_test + 10
test_model_instance = UNetWithTransformerBottleneck(
in_channels=3,
out_channels=3,
init_features=init_features_test,
d_model=d_model_test,
nhead=nhead_test,
num_transformer_layers=2,
dim_feedforward=d_model_test * 4,
pos_encoding_max_len=max_len_estimate_test
).to(device)
print(f"Model instantiated successfully on {device}.")
# test w dummy input
dummy_input_test = torch.randn(batch_size_test, 3, img_size_test, img_size_test).to(device)
print(f"Dummy input shape: {dummy_input_test.shape}")
with torch.no_grad():
output_test = test_model_instance(dummy_input_test)
print(f"Dummy output shape: {output_test.shape}")
print("Forward pass completed successfully.")
total_params = sum(p.numel() for p in test_model_instance.parameters())
trainable_params = sum(p.numel() for p in test_model_instance.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
except Exception as e:
print(f"\nAn ERROR occurred during the instantiation test: {e}")
import traceback
traceback.print_exc()
print("--- End of Basic Model Instantiation Test ---\n")
lpips package found. pytorch-msssim package found. --- Running Basic Model Instantiation Test --- Using device for test: cuda
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True warnings.warn(
Model instantiated successfully on cuda. Dummy input shape: torch.Size([2, 3, 128, 128]) Dummy output shape: torch.Size([2, 3, 128, 128]) Forward pass completed successfully. Total parameters: 880,291 Trainable parameters: 880,291 --- End of Basic Model Instantiation Test ---
In [ ]:
# @title U-Net where we refine features coming thorugh skip connections (using Spatial attention) before concatenated in the decoder - Unet with AttentiveSkips
# helper module: Spatial Attention Block
class SpatialAttention(nn.Module):
"""
Spatial Attention Module (inspired by CBAM).
Applies AvgPool and MaxPool across channels, concatenates them,
passes through a Conv layer to generate a spatial attention map.
"""
def __init__(self, kernel_size=7):
"""
Args:
kernel_size: Kernel size for the convolution layer.
"""
super().__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
# input has 2 channels (AvgPool + MaxPool)
self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
"""
Args:
x: Input feature map (B, C, H, W)
Returns:
Feature map refined by spatial attention (B, C, H, W)
"""
# pool across the channel dimension
avg_out = torch.mean(x, dim=1, keepdim=True) # (B, 1, H, W)
max_out, _ = torch.max(x, dim=1, keepdim=True) # (B, 1, H, W)
# concatenate pooled features
pooled = torch.cat([avg_out, max_out], dim=1) # (B, 2, H, W)
# gen. spatial attention map
attention_map = self.conv(pooled) # (B, 1, H, W)
attention_map = self.sigmoid(attention_map) # Values between 0 and 1
# apply attention map to the original input feature map
return x * attention_map # Element-wise multiplication, broadcasts
# MAIN module: U-Net with Attentive Skips
class UNetWithAttentiveSkips(nn.Module):
def __init__(self, in_channels=3, out_channels=3, init_features=32, attention_kernel_size=7):
"""
Args:
in_channels: Number of input channels (usually 3 for RGB).
out_channels: Number of output channels (usually 3 for RGB bokeh).
init_features: Number of features in the first convolutional layer.
attention_kernel_size: Kernel size for the SpatialAttention modules.
"""
super().__init__()
features = init_features
# ENCODERR
self.enc1_conv = self.unet_block(in_channels, features, name="enc1_conv")
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.enc2_conv = self.unet_block(features, features * 2, name="enc2_conv")
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.enc3_conv = self.unet_block(features * 2, features * 4, name="enc3_conv") # Bottleneck
# Skip Attention Modules
# att. applied to the output of enc1 before concatenation
self.skip_attn1 = SpatialAttention(kernel_size=attention_kernel_size)
# att. applied to the output of enc2 before concatenation
self.skip_attn2 = SpatialAttention(kernel_size=attention_kernel_size)
# DECODER
# upsample bottleneck features (output of enc3_conv)
self.up2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
# decoder conv block takes upsampled + *refined* skip connection (output of skip_attn2)
self.dec2_conv = self.unet_block((features * 2) * 2, features * 2, name="dec2_conv")
# upsample dec2 features (output of dec2_conv)
self.up1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
# decoder conv block takes upsampled + *refined* skip connection (output of skip_attn1)
self.dec1_conv = self.unet_block(features * 2, features, name="dec1_conv")
self.output_conv = nn.Conv2d(features, out_channels, kernel_size=1)
def forward(self, x):
# ENCODER
enc1_out = self.enc1_conv(x) # for skip connection 1
p1 = self.pool1(enc1_out)
enc2_out = self.enc2_conv(p1) # fo skip connection 2
p2 = self.pool2(enc2_out)
bottleneck = self.enc3_conv(p2)
# DECODER
# Decoder lvl2
up2_out = self.up2(bottleneck)
# apply Spatial Attention to Skip Connection 2
refined_enc2_out = self.skip_attn2(enc2_out)
# pad upsampled features to match refined skip connection size
diffY = refined_enc2_out.size()[2] - up2_out.size()[2]
diffX = refined_enc2_out.size()[3] - up2_out.size()[3]
padded_up2_out = F.pad(up2_out, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# concat. refined skip connection
cat2 = torch.cat([padded_up2_out, refined_enc2_out], dim=1)
dec2_out = self.dec2_conv(cat2)
# decoder lvl 1
up1_out = self.up1(dec2_out)
# Apply Spatial Attention to Skip Connection 1
refined_enc1_out = self.skip_attn1(enc1_out)
# pad upsampled features to match refined skip connection size
diffY = refined_enc1_out.size()[2] - up1_out.size()[2]
diffX = refined_enc1_out.size()[3] - up1_out.size()[3]
padded_up1_out = F.pad(up1_out, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# concat. refined skip connection
cat1 = torch.cat([padded_up1_out, refined_enc1_out], dim=1)
dec1_out = self.dec1_conv(cat1)
# final output
output = self.output_conv(dec1_out)
return torch.sigmoid(output)
@staticmethod
def unet_block(in_channels, features, name):
# typical U-Net block: Conv -> BN -> ReLU -> Conv -> BN -> ReLU
return nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(num_features=features),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(num_features=features),
nn.ReLU(inplace=True),
)
# exxample Instantiation & test
print("\n--- Running Basic Model Instantiation Test (U-Net with Attentive Skips) ---")
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device for test: {device}")
img_size_test = 128
batch_size_test = 4
init_features_test = 32
test_model_instance = UNetWithAttentiveSkips(
in_channels=3,
out_channels=3,
init_features=init_features_test,
attention_kernel_size=7 # Default kernel size for spatial attention
).to(device)
print(f"Model (UNetWithAttentiveSkips) instantiated successfully on {device}.")
# test w dummy input
dummy_input_test = torch.randn(batch_size_test, 3, img_size_test, img_size_test).to(device)
print(f"Dummy input shape: {dummy_input_test.shape}")
with torch.no_grad():
output_test = test_model_instance(dummy_input_test)
print(f"Dummy output shape: {output_test.shape}")
print("Forward pass completed successfully.")
total_params = sum(p.numel() for p in test_model_instance.parameters())
trainable_params = sum(p.numel() for p in test_model_instance.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
except Exception as e:
print(f"\nAn ERROR occurred during the instantiation test: {e}")
import traceback
traceback.print_exc()
print("--- End of Basic Model Instantiation Test (U-Net with Attentive Skips) ---\n")
lpips package found. pytorch-msssim package found. --- Running Basic Model Instantiation Test (U-Net with Attentive Skips) --- Using device for test: cuda Model (UNetWithAttentiveSkips) instantiated successfully on cuda. Dummy input shape: torch.Size([4, 3, 128, 128]) Dummy output shape: torch.Size([4, 3, 128, 128]) Forward pass completed successfully. Total parameters: 467,431 Trainable parameters: 467,431 --- End of Basic Model Instantiation Test (U-Net with Attentive Skips) ---